[AIRFLOW-492] Make sure stat updates cannot fail a task

Previously, a failed commit of the statistics to the db
could also fail a task. In addition, the UI could display
out-of-date statistics.

This patch reworks DagStat so that a failure to update the
statistics does not propagate. It also makes sure the UI
always displays the latest statistics.

Closes #2254 from bolkedebruin/AIRFLOW-492

(cherry picked from commit c2472ffa124ffc65b8762ea583554494624dbb6a)
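
The heart of the change is in the models.py hunk further down: DagStat.update()
recomputes the per-(dag_id, state) counts inside a try/except and, on any error,
rolls back and logs instead of raising, so a broken stats commit can no longer
take a task down with it. Below is a minimal, self-contained sketch of that
pattern, not Airflow's actual code: it assumes SQLAlchemy 1.4+ and an in-memory
SQLite database, and the DagStat/DagRun models and DAG_STATES tuple are
illustrative stand-ins.

import itertools
import logging

from sqlalchemy import Boolean, Column, Integer, String, create_engine, func
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()
DAG_STATES = ("running", "success", "failed")   # stand-in for State.dag_states


class DagStat(Base):
    __tablename__ = "dag_stats"
    dag_id = Column(String(250), primary_key=True)
    state = Column(String(50), primary_key=True)
    count = Column(Integer, default=0)
    dirty = Column(Boolean, default=False)


class DagRun(Base):
    __tablename__ = "dag_run"
    id = Column(Integer, primary_key=True)
    dag_id = Column(String(250))
    state = Column(String(50))


def update_stats(session, dag_ids=None, dirty_only=True):
    """Recompute DagStat rows; never let a failure propagate to the caller."""
    try:
        qry = session.query(DagStat)
        if dag_ids is not None:
            qry = qry.filter(DagStat.dag_id.in_(set(dag_ids)))
        if dirty_only:
            qry = qry.filter(DagStat.dirty == True)  # noqa: E712

        ids = {row.dag_id for row in qry.all()}
        if not ids:                     # avoid an empty IN clause below
            session.commit()
            return

        counts = {
            (dag_id, state): cnt
            for dag_id, state, cnt in (
                session.query(DagRun.dag_id, DagRun.state, func.count("*"))
                .filter(DagRun.dag_id.in_(ids))
                .group_by(DagRun.dag_id, DagRun.state)
            )
        }
        # write a row for every (dag_id, state) pair, defaulting to a count of 0
        for dag_id, state in itertools.product(ids, DAG_STATES):
            session.merge(DagStat(dag_id=dag_id, state=state,
                                  count=counts.get((dag_id, state), 0),
                                  dirty=False))
        session.commit()
    except Exception:
        # stats are best effort: swallow the error so the caller keeps running
        session.rollback()
        logging.exception("Could not update dag stat table")


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()
    session.add(DagStat(dag_id="example", state="running", dirty=True))
    session.add(DagRun(dag_id="example", state="running"))
    session.commit()
    update_stats(session)
    print([(s.dag_id, s.state, s.count) for s in session.query(DagStat)])

The second half of the fix is visible in the views.py hunk: the dashboard now
calls DagStat.update() before reading the table, so the counts it renders are
refreshed at read time rather than relying on an earlier clean_dirty() pass.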


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/e342d0d2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/e342d0d2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/e342d0d2

Branch: refs/heads/v1-8-stable
Commit: e342d0d223e47ea25f73baaa00a16df414a6e0df
Parents: 5800f56
Author: Bolke de Bruin <bo...@xs4all.nl>
Authored: Wed Apr 26 20:39:48 2017 +0200
Committer: Chris Riccomini <criccom...@apache.org>
Committed: Thu Apr 27 12:35:46 2017 -0700

----------------------------------------------------------------------
 airflow/jobs.py      |   4 +-
 airflow/models.py    | 133 ++++++++++++++++++++++++++++++++++------------
 airflow/www/views.py |   7 +--
 tests/core.py        |  34 +++++++-----
 tests/models.py      |  66 ++++++++++++++++++++++-
 5 files changed, 190 insertions(+), 54 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e342d0d2/airflow/jobs.py
----------------------------------------------------------------------
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 11dbddf..379c96e 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -1177,7 +1177,7 @@ class SchedulerJob(BaseJob):
             self._process_task_instances(dag, tis_out)
             self.manage_slas(dag)
 
-        models.DagStat.clean_dirty([d.dag_id for d in dags])
+        models.DagStat.update([d.dag_id for d in dags])
 
     def _process_executor_events(self):
         """
@@ -1977,7 +1977,7 @@ class BackfillJob(BaseJob):
                     active_dag_runs.remove(run)
 
                 if run.dag.is_paused:
-                    models.DagStat.clean_dirty([run.dag_id], session=session)
+                    models.DagStat.update([run.dag_id], session=session)
 
             msg = ' | '.join([
                 "[backfill progress]",

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e342d0d2/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 2de88f6..1ceb821 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -29,6 +29,7 @@ import functools
 import getpass
 import imp
 import importlib
+import itertools
 import inspect
 import zipfile
 import jinja2
@@ -719,6 +720,7 @@ class TaskInstance(Base):
     even while multiple schedulers may be firing task instances.
     """
 
+
     __tablename__ = "task_instance"
 
     task_id = Column(String(ID_LEN), primary_key=True)
@@ -3089,7 +3091,7 @@ class DAG(BaseDag, LoggingMixin):
         for dr in drs:
             dr.state = state
             dirty_ids.append(dr.dag_id)
-        DagStat.clean_dirty(dirty_ids, session=session)
+        DagStat.update(dirty_ids, session=session)
 
     def clear(
             self, start_date=None, end_date=None,
@@ -3383,6 +3385,9 @@ class DAG(BaseDag, LoggingMixin):
             state=state
         )
         session.add(run)
+
+        DagStat.set_dirty(dag_id=self.dag_id, session=session)
+
         session.commit()
 
         run.dag = self
@@ -3392,12 +3397,7 @@ class DAG(BaseDag, LoggingMixin):
         run.verify_integrity(session=session)
 
         run.refresh_from_db()
-        DagStat.set_dirty(self.dag_id, session=session)
 
-        # add a placeholder row into DagStat table
-        if not session.query(DagStat).filter(DagStat.dag_id == self.dag_id).first():
-            session.add(DagStat(dag_id=self.dag_id, state=state, count=0, dirty=True))
-        session.commit()
         return run
 
     @staticmethod
@@ -3805,7 +3805,7 @@ class DagStat(Base):
     count = Column(Integer, default=0)
     dirty = Column(Boolean, default=False)
 
-    def __init__(self, dag_id, state, count, dirty=False):
+    def __init__(self, dag_id, state, count=0, dirty=False):
         self.dag_id = dag_id
         self.state = state
         self.count = count
@@ -3814,42 +3814,104 @@ class DagStat(Base):
     @staticmethod
     @provide_session
     def set_dirty(dag_id, session=None):
-        for dag in session.query(DagStat).filter(DagStat.dag_id == dag_id):
-            dag.dirty = True
-        session.commit()
+        """
+        :param dag_id: the dag_id to mark dirty
+        :param session: database session
+        :return: 
+        """
+        DagStat.create(dag_id=dag_id, session=session)
+
+        try:
+            stats = session.query(DagStat).filter(
+                DagStat.dag_id == dag_id
+            ).with_for_update().all()
+
+            for stat in stats:
+                stat.dirty = True
+            session.commit()
+        except Exception as e:
+            session.rollback()
+            logging.warning("Could not update dag stats for {}".format(dag_id))
+            logging.exception(e)
 
     @staticmethod
     @provide_session
-    def clean_dirty(dag_ids, session=None):
+    def update(dag_ids=None, dirty_only=True, session=None):
         """
-        Cleans out the dirty/out-of-sync rows from dag_stats table
+        Updates the stats for dirty/out-of-sync dags
 
-        :param dag_ids: dag_ids that may be dirty
+        :param dag_ids: dag_ids to be updated
         :type dag_ids: list
-        :param full_query: whether to check dag_runs for new drs not in dag_stats
-        :type full_query: bool
+        :param dirty_only: only updated for marked dirty, defaults to True
+        :type dirty_only: bool
+        :param session: db session to use
+        :type session: Session
         """
-        dag_ids = set(dag_ids)
+        if dag_ids is not None:
+            dag_ids = set(dag_ids)
 
-        qry = (
-            session.query(DagStat)
-            .filter(and_(DagStat.dag_id.in_(dag_ids), DagStat.dirty == True))
-        )
+        try:
+            qry = session.query(DagStat)
 
-        dirty_ids = {dag.dag_id for dag in qry.all()}
-        qry.delete(synchronize_session='fetch')
-        session.commit()
+            if dag_ids is not None:
+                qry = qry.filter(DagStat.dag_id.in_(dag_ids))
+            if dirty_only:
+                qry = qry.filter(DagStat.dirty == True)
 
-        qry = (
-            session.query(DagRun.dag_id, DagRun.state, func.count('*'))
-            .filter(DagRun.dag_id.in_(dirty_ids))
-            .group_by(DagRun.dag_id, DagRun.state)
-        )
+            qry = qry.with_for_update().all()
 
-        for dag_id, state, count in qry:
-            session.add(DagStat(dag_id=dag_id, state=state, count=count))
+            ids = set([dag_stat.dag_id for dag_stat in qry])
 
-        session.commit()
+            # avoid querying with an empty IN clause
+            if len(ids) == 0:
+                session.commit()
+                return
+
+            dagstat_states = set(itertools.product(ids, State.dag_states))
+            qry = (
+                session.query(DagRun.dag_id, DagRun.state, func.count('*'))
+                .filter(DagRun.dag_id.in_(ids))
+                .group_by(DagRun.dag_id, DagRun.state)
+            )
+
+            counts = {(dag_id, state): count for dag_id, state, count in qry}
+            for dag_id, state in dagstat_states:
+                count = 0
+                if (dag_id, state) in counts:
+                    count = counts[(dag_id, state)]
+
+                session.merge(
+                    DagStat(dag_id=dag_id, state=state, count=count, dirty=False)
+                )
+
+            session.commit()
+        except Exception as e:
+            session.rollback()
+            logging.warning("Could not update dag stat table")
+            logging.exception(e)
+
+    @staticmethod
+    @provide_session
+    def create(dag_id, session=None):
+        """
+        Creates the missing states the stats table for the dag specified 
+        
+        :param dag_id: dag id of the dag to create stats for
+        :param session: database session
+        :return: 
+        """
+        # unfortunately sqlalchemy does not know upsert
+        qry = session.query(DagStat).filter(DagStat.dag_id == dag_id).all()
+        states = [dag_stat.state for dag_stat in qry]
+        for state in State.dag_states:
+            if state not in states:
+                try:
+                    session.merge(DagStat(dag_id=dag_id, state=state))
+                    session.commit()
+                except Exception as e:
+                    session.rollback()
+                    logging.warning("Could not create stat record")
+                    logging.exception(e)
 
 
 class DagRun(Base):
@@ -3895,10 +3957,11 @@ class DagRun(Base):
     def set_state(self, state):
         if self._state != state:
             self._state = state
-            # something really weird goes on here: if you try to close the session
-            # dag runs will end up detached
-            session = settings.Session()
-            DagStat.set_dirty(self.dag_id, session=session)
+            if self.dag_id is not None:
+                # something really weird goes on here: if you try to close the session
+                # dag runs will end up detached
+                session = settings.Session()
+                DagStat.set_dirty(self.dag_id, session=session)
 
     @declared_attr
     def state(self):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e342d0d2/airflow/www/views.py
----------------------------------------------------------------------
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 53c6394..3ed58a8 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -468,6 +468,8 @@ class Airflow(BaseView):
         ds = models.DagStat
         session = Session()
 
+        ds.update()
+
         qry = (
             session.query(ds.dag_id, ds.state, ds.count)
         )
@@ -2247,9 +2249,8 @@ class DagRunModelView(ModelViewOnly):
         session.commit()
         dirty_ids = []
         for row in deleted:
-            models.DagStat.set_dirty(row.dag_id, session=session)
             dirty_ids.append(row.dag_id)
-        models.DagStat.clean_dirty(dirty_ids, session=session)
+        models.DagStat.update(dirty_ids, dirty_only=False, session=session)
         session.close()
 
     @action('set_running', "Set state to 'running'", None)
@@ -2279,7 +2280,7 @@ class DagRunModelView(ModelViewOnly):
                 else:
                     dr.end_date = datetime.now()
             session.commit()
-            models.DagStat.clean_dirty(dirty_ids, session=session)
+            models.DagStat.update(dirty_ids, session=session)
             flash(
                 "{count} dag runs were set to 
'{target_state}'".format(**locals()))
         except Exception as ex:

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e342d0d2/tests/core.py
----------------------------------------------------------------------
diff --git a/tests/core.py b/tests/core.py
index c36c6c2..4343cb3 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -979,40 +979,48 @@ class CoreTest(unittest.TestCase):
         session.query(models.DagStat).delete()
         session.commit()
 
+        models.DagStat.update([], session=session)
+
         run1 = self.dag_bash.create_dagrun(
             run_id="run1",
             execution_date=DEFAULT_DATE,
             state=State.RUNNING)
 
-        models.DagStat.clean_dirty([self.dag_bash.dag_id], session=session)
+        models.DagStat.update([self.dag_bash.dag_id], session=session)
 
         qry = session.query(models.DagStat).all()
 
-        assert len(qry) == 1
-        assert qry[0].dag_id == self.dag_bash.dag_id and\
-                qry[0].state == State.RUNNING and\
-                qry[0].count == 1 and\
-                qry[0].dirty == False
+        self.assertEqual(3, len(qry))
+        self.assertEqual(self.dag_bash.dag_id, qry[0].dag_id)
+        for stats in qry:
+            if stats.state == State.RUNNING:
+                self.assertEqual(stats.count, 1)
+            else:
+                self.assertEqual(stats.count, 0)
+            self.assertFalse(stats.dirty)
 
         run2 = self.dag_bash.create_dagrun(
             run_id="run2",
             execution_date=DEFAULT_DATE+timedelta(days=1),
             state=State.RUNNING)
 
-        models.DagStat.clean_dirty([self.dag_bash.dag_id], session=session)
+        models.DagStat.update([self.dag_bash.dag_id], session=session)
 
         qry = session.query(models.DagStat).all()
 
-        assert len(qry) == 1
-        assert qry[0].dag_id == self.dag_bash.dag_id and\
-                qry[0].state == State.RUNNING and\
-                qry[0].count == 2 and\
-                qry[0].dirty == False
+        self.assertEqual(3, len(qry))
+        self.assertEqual(self.dag_bash.dag_id, qry[0].dag_id)
+        for stats in qry:
+            if stats.state == State.RUNNING:
+                self.assertEqual(stats.count, 2)
+            else:
+                self.assertEqual(stats.count, 0)
+            self.assertFalse(stats.dirty)
 
         session.query(models.DagRun).first().state = State.SUCCESS
         session.commit()
 
-        models.DagStat.clean_dirty([self.dag_bash.dag_id], session=session)
+        models.DagStat.update([self.dag_bash.dag_id], session=session)
 
         qry = session.query(models.DagStat).filter(models.DagStat.state == State.SUCCESS).all()
         assert len(qry) == 1

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e342d0d2/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index 8223276..981561a 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -27,7 +27,7 @@ from airflow import models, settings, AirflowException
 from airflow.exceptions import AirflowSkipException
 from airflow.models import DAG, TaskInstance as TI
 from airflow.models import State as ST
-from airflow.models import DagModel
+from airflow.models import DagModel, DagStat
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.bash_operator import BashOperator
 from airflow.operators.python_operator import PythonOperator
@@ -234,6 +234,55 @@ class DagTest(unittest.TestCase):
         session.close()
 
 
+class DagStatTest(unittest.TestCase):
+    def test_dagstats_crud(self):
+        DagStat.create(dag_id='test_dagstats_crud')
+
+        session = settings.Session()
+        qry = session.query(DagStat).filter(DagStat.dag_id == 'test_dagstats_crud')
+        self.assertEqual(len(qry.all()), len(State.dag_states))
+
+        DagStat.set_dirty(dag_id='test_dagstats_crud')
+        res = qry.all()
+
+        for stat in res:
+            self.assertTrue(stat.dirty)
+
+        # create missing
+        DagStat.set_dirty(dag_id='test_dagstats_crud_2')
+        qry2 = session.query(DagStat).filter(DagStat.dag_id == 'test_dagstats_crud_2')
+        self.assertEqual(len(qry2.all()), len(State.dag_states))
+
+        dag = DAG(
+            'test_dagstats_crud',
+            start_date=DEFAULT_DATE,
+            default_args={'owner': 'owner1'})
+
+        with dag:
+            op1 = DummyOperator(task_id='A')
+
+        now = datetime.datetime.now()
+        dr = dag.create_dagrun(
+            run_id='manual__' + now.isoformat(),
+            execution_date=now,
+            start_date=now,
+            state=State.FAILED,
+            external_trigger=False,
+        )
+
+        DagStat.update(dag_ids=['test_dagstats_crud'])
+        res = qry.all()
+        for stat in res:
+            if stat.state == State.FAILED:
+                self.assertEqual(stat.count, 1)
+            else:
+                self.assertEqual(stat.count, 0)
+
+        DagStat.update()
+        res = qry2.all()
+        for stat in res:
+            self.assertFalse(stat.dirty)
+
 class DagRunTest(unittest.TestCase):
 
     def create_dag_run(self, dag, state=State.RUNNING, task_states=None):
@@ -412,6 +461,21 @@ class DagRunTest(unittest.TestCase):
         ti = dag_run.get_task_instance('test_short_circuit_false')
         self.assertEqual(None, ti)
 
+    def test_get_latest_runs(self):
+        session = settings.Session()
+        dag = DAG(
+            dag_id='test_latest_runs_1',
+            start_date=DEFAULT_DATE)
+        dag_1_run_1 = self.create_dag_run(dag,
+                execution_date=datetime.datetime(2015, 1, 1))
+        dag_1_run_2 = self.create_dag_run(dag,
+                execution_date=datetime.datetime(2015, 1, 2))
+        dagruns = models.DagRun.get_latest_runs(session)
+        session.close()
+        for dagrun in dagruns:
+            if dagrun.dag_id == 'test_latest_runs_1':
+                self.assertEqual(dagrun.execution_date, datetime.datetime(2015, 1, 2))
+
 
 class DagBagTest(unittest.TestCase):
 
