Repository: incubator-airflow
Updated Branches:
  refs/heads/master 7e6e84385 -> a1f4227be


[AIRFLOW-1237] Fix IN-predicate sqlalchemy warning

Closes #2320 from skudriashev/airflow-1237
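
For context: SQLAlchemy releases of this era emit a SAWarning along the
lines of "The IN-predicate was invoked with an empty sequence. This
results in a contradiction" whenever in_() is handed an empty sequence.
The hunks below all apply the same recipe: only attach the in_() filter
when the sequence is non-empty. A minimal, self-contained reproduction
(the table and column names are illustrative, not Airflow's schema):

    import warnings
    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker

    Base = declarative_base()

    class ImportErrorRow(Base):  # stand-in for airflow.models.ImportError
        __tablename__ = 'import_error'
        id = Column(Integer, primary_key=True)
        filename = Column(String(250))

    engine = create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        # Empty sequence -> SAWarning on SQLAlchemy versions before 1.4.
        session.query(ImportErrorRow).filter(
            ~ImportErrorRow.filename.in_([])
        ).delete(synchronize_session='fetch')
    print([str(w.message) for w in caught])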


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/a1f4227b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/a1f4227b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/a1f4227b

Branch: refs/heads/master
Commit: a1f4227bee1a70531cfa90769149322513cb6f92
Parents: 7e6e843
Author: Stanislav Kudriashev <[email protected]>
Authored: Mon May 29 17:23:11 2017 +0200
Committer: Bolke de Bruin <[email protected]>
Committed: Mon May 29 17:23:11 2017 +0200

----------------------------------------------------------------------
 airflow/jobs.py                           |  9 +++--
 airflow/models.py                         | 30 +++++++-------
 airflow/operators/latest_only_operator.py | 56 +++++++++++++-------------
 tests/jobs.py                             |  1 +
 4 files changed, 50 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a1f4227b/airflow/jobs.py
----------------------------------------------------------------------
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 0c724a0..adc4328 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -690,9 +690,12 @@ class SchedulerJob(BaseJob):
         :param known_file_paths: The list of existing files that are parsed for DAGs
         :type known_file_paths: list[unicode]
         """
-        session.query(models.ImportError).filter(
-            ~models.ImportError.filename.in_(known_file_paths)
-        ).delete(synchronize_session='fetch')
+        query = session.query(models.ImportError)
+        if known_file_paths:
+            query = query.filter(
+                ~models.ImportError.filename.in_(known_file_paths)
+            )
+        query.delete(synchronize_session='fetch')
         session.commit()
 
     @staticmethod
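
A note on the hunk above: with an empty known_file_paths, the old
~filename.in_([]) predicate was the negation of a contradiction, i.e. a
tautology, so every ImportError row was deleted; the new code reaches the
same rows by skipping the filter entirely, just without the warning. The
guard generalizes to any column; a hedged sketch (the helper name is
illustrative, not part of Airflow):

    def delete_rows_not_in(session, model, column, values):
        """Delete rows whose `column` value is not in `values`.

        An empty `values` deletes every row, matching the old
        ~column.in_([]) behavior while avoiding the SAWarning.
        """
        query = session.query(model)
        if values:
            query = query.filter(~column.in_(values))
        query.delete(synchronize_session='fetch')
        session.commit()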

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a1f4227b/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 9a075a8..30e18a4 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -125,16 +125,16 @@ def clear_task_instances(tis, session, activate_dag_runs=True):
         #     session.merge(ti)
         else:
             session.delete(ti)
+
     if job_ids:
         from airflow.jobs import BaseJob as BJ
         for job in session.query(BJ).filter(BJ.id.in_(job_ids)).all():
             job.state = State.SHUTDOWN
-    if activate_dag_runs:
-        execution_dates = {ti.execution_date for ti in tis}
-        dag_ids = {ti.dag_id for ti in tis}
+
+    if activate_dag_runs and tis:
         drs = session.query(DagRun).filter(
-            DagRun.dag_id.in_(dag_ids),
-            DagRun.execution_date.in_(execution_dates),
+            DagRun.dag_id.in_({ti.dag_id for ti in tis}),
+            DagRun.execution_date.in_({ti.execution_date for ti in tis}),
         ).all()
         for dr in drs:
             dr.state = State.RUNNING
@@ -2374,7 +2374,7 @@ class BaseOperator(object):
 
         count = qry.count()
 
-        clear_task_instances(qry, session)
+        clear_task_instances(qry.all(), session)
 
         session.commit()
         session.close()
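
Why the call sites now pass qry.all() instead of the Query itself: the new
`if activate_dag_runs and tis` guard in clear_task_instances relies on tis
being falsy when empty, and a SQLAlchemy Query object is always truthy no
matter how many rows it matches. Materializing to a list makes the
emptiness check meaningful, and also avoids re-running the SELECT for each
of the two set comprehensions. A small illustration, assuming a `session`
and the TI alias used in the hunks below:

    empty = session.query(TI).filter(TI.task_id == 'no-such-task')
    bool(empty)        # True  -- plain object truthiness, no SQL executed
    bool(empty.all())  # False -- empty list, so the in_() filters are skipped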
@@ -3165,9 +3165,11 @@ class DAG(BaseDag, LoggingMixin):
             # Crafting the right filter for dag_id and task_ids combo
             conditions = []
             for dag in self.subdags + [self]:
-                conditions.append(
-                    TI.dag_id.like(dag.dag_id) & TI.task_id.in_(dag.task_ids)
-                )
+                if dag.task_ids:
+                    conditions.append(
+                        TI.dag_id.like(dag.dag_id) &
+                        TI.task_id.in_(dag.task_ids)
+                    )
             tis = tis.filter(or_(*conditions))
         else:
             tis = session.query(TI).filter(TI.dag_id == self.dag_id)
@@ -3201,7 +3203,7 @@ class DAG(BaseDag, LoggingMixin):
             do_it = utils.helpers.ask_yesno(question)
 
         if do_it:
-            clear_task_instances(tis, session)
+            clear_task_instances(tis.all(), session)
             if reset_dag_runs:
                 self.set_dag_runs_state(session=session)
         else:
@@ -3917,14 +3919,10 @@ class DagStat(Base):
         :param session: db session to use
         :type session: Session
         """
-        if dag_ids is not None:
-            dag_ids = set(dag_ids)
-
         try:
             qry = session.query(DagStat)
-
-            if dag_ids is not None:
-                qry = qry.filter(DagStat.dag_id.in_(dag_ids))
+            if dag_ids:
+                qry = qry.filter(DagStat.dag_id.in_(set(dag_ids)))
             if dirty_only:
                 qry = qry.filter(DagStat.dirty == True)
 
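
One subtle change in the DagStat hunk above: the guard moves from
`dag_ids is not None` to plain truthiness, so an explicitly empty
collection now means "no dag_id filter" (touch every row) rather than an
IN () contradiction that matched nothing. Assuming this method is
DagStat.update with the signature of this era of Airflow:

    DagStat.update(dag_ids=[])  # before: IN () matched no rows (plus SAWarning)
                                # after:  filter skipped, all dirty stats updated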

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a1f4227b/airflow/operators/latest_only_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py
index 9d5defb..f1d8085 100644
--- a/airflow/operators/latest_only_operator.py
+++ b/airflow/operators/latest_only_operator.py
@@ -49,38 +49,40 @@ class LatestOnlyOperator(BaseOperator):
 
         if not left_window < now <= right_window:
             logging.info('Not latest execution, skipping downstream.')
-            session = settings.Session()
+            downstream_task_ids = context['task'].downstream_task_ids
+            if downstream_task_ids:
+                session = settings.Session()
+                TI = TaskInstance
+                tis = session.query(TI).filter(
+                    TI.execution_date == context['ti'].execution_date,
+                    TI.task_id.in_(downstream_task_ids)
+                ).with_for_update().all()
 
-            TI = TaskInstance
-            tis = session.query(TI).filter(
-                TI.execution_date == context['ti'].execution_date,
-                TI.task_id.in_(context['task'].downstream_task_ids)
-            ).with_for_update().all()
+                for ti in tis:
+                    logging.info('Skipping task: %s', ti.task_id)
+                    ti.state = State.SKIPPED
+                    ti.start_date = now
+                    ti.end_date = now
+                    session.merge(ti)
 
-            for ti in tis:
-                logging.info('Skipping task: %s', ti.task_id)
-                ti.state = State.SKIPPED
-                ti.start_date = now
-                ti.end_date = now
-                session.merge(ti)
+                # this is defensive against dag runs that are not complete
+                for task in context['task'].downstream_list:
+                    if task.task_id in tis:
+                        continue
 
-            # this is defensive against dag runs that are not complete
-            for task in context['task'].downstream_list:
-                if task.task_id in tis:
-                    continue
+                    logging.warning("Task {} was not part of a dag run. "
+                                    "This should not happen."
+                                    .format(task))
+                    now = datetime.datetime.now()
+                    ti = TaskInstance(task, execution_date=context['ti'].execution_date)
+                    ti.state = State.SKIPPED
+                    ti.start_date = now
+                    ti.end_date = now
+                    session.merge(ti)
 
-                logging.warning("Task {} was not part of a dag run. "
-                                "This should not happen."
-                                .format(task))
-                now = datetime.datetime.now()
-                ti = TaskInstance(task, execution_date=context['ti'].execution_date)
-                ti.state = State.SKIPPED
-                ti.start_date = now
-                ti.end_date = now
-                session.merge(ti)
+                session.commit()
+                session.close()
 
-            session.commit()
-            session.close()
             logging.info('Done.')
         else:
             logging.info('Latest, allowing execution to proceed.')
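
An aside on the defensive loop carried over above: `if task.task_id in
tis` tests a task-id string for membership in a list of TaskInstance
objects, which should never match, so the warning branch fires even for
tasks that were just skipped. A correct membership test (a sketch, not
part of this commit) would compare ids instead:

    skipped_ids = {ti.task_id for ti in tis}
    for task in context['task'].downstream_list:
        if task.task_id in skipped_ids:
            continue
        ...  # create and skip the missing TaskInstance as above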

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a1f4227b/tests/jobs.py
----------------------------------------------------------------------
diff --git a/tests/jobs.py b/tests/jobs.py
index 9ebea15..b0763b9 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -529,6 +529,7 @@ class SchedulerJobTest(unittest.TestCase):
     def setUp(self):
         self.dagbag = DagBag()
         session = settings.Session()
+        session.query(models.DagRun).delete()
         session.query(models.ImportError).delete()
         session.commit()
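
On the test change: SchedulerJobTest.setUp now clears DagRun rows as well,
so runs created by earlier tests cannot leak into scheduler assertions.
The same isolation pattern, sketched generically (the helper name is
illustrative):

    def truncate_tables(session, *models):
        # Remove leftover rows so state cannot leak between tests.
        for model in models:
            session.query(model).delete()
        session.commit()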
 
