seanxwzhang commented on a change in pull request #8545:
URL: https://github.com/apache/airflow/pull/8545#discussion_r418254475



##########
File path: airflow/models/baseoperator.py
##########
@@ -392,10 +442,79 @@ def __init__(
                                    % (self.task_id, dag.dag_id))
         self.sla = sla
         self.execution_timeout = execution_timeout
+
+        # Warn about use of the deprecated SLA parameter
+        if sla and expected_finish:
+            warnings.warn(
+                "Both sla and expected_finish provided as task "
+                "parameters to {}; using expected_finish and ignoring "
+                "deprecated sla parameter.".format(self),
+                category=PendingDeprecationWarning
+            )
+        elif sla:
+            warnings.warn(
+                "sla is deprecated as a task parameter for {}; converting to "
+                "expected_finish instead.".format(self),
+                category=PendingDeprecationWarning
+            )
+            expected_finish = sla
+
+        # Set SLA parameters, batching invalid type messages into a
+        # single exception.
+        sla_param_errs: List = []
+        if expected_duration and not isinstance(expected_duration, timedelta):
+            sla_param_errs.append("expected_duration must be a timedelta, "
+                                  "got: {}".format(expected_duration))
+        if expected_start and not isinstance(expected_start, timedelta):
+            sla_param_errs.append("expected_start must be a timedelta, "
+                                  "got: {}".format(expected_start))
+        if expected_finish and not isinstance(expected_finish, timedelta):
+            sla_param_errs.append("expected_finish must be a timedelta, "
+                                  "got: {}".format(expected_finish))
+        if sla_param_errs:
+            raise AirflowException("Invalid SLA params were set! {}".format(
+                "; ".join(sla_param_errs)))
+
+        # If no exception has been raised, go ahead and set these.
+        self.expected_duration = expected_duration
+        self.expected_start = expected_start
+        self.expected_finish = expected_finish
+
+        # Warn the user if they've set any non-sensical parameter combinations
+        if self.expected_start and self.expected_finish \
+                and self.expected_start >= self.expected_finish:
+            self.log.warning(
+                "Task %s has an expected_start (%s) that occurs after its "
+                "expected_finish (%s), so it will always send an SLA "
+                "notification.",
+                self, self.expected_start, self.expected_finish
+            )
+
+            if self.expected_duration and self.expected_start \

Review comment:
       Good catch!

##########
File path: airflow/models/dag.py
##########
@@ -1582,6 +1580,196 @@ def sync_to_db(self, sync_time=None, session=None):
         """
         self.bulk_sync_to_db([self], sync_time, session)
 
+    @provide_session
+    def manage_slas(self, session=None):
+        """
+        Helper function to encapsulate the sequence of SLA operations.
+        """
+        # Create SlaMiss objects for the various types of SLA misses.
+        self.record_sla_misses(session=session)
+
+        # Collect pending SLA miss callbacks, either created immediately above
+        # or previously failed.
+        unsent_sla_misses = self.get_unsent_sla_notifications(session=session)
+        self.log.debug("Found %s unsent SLA miss notifications",
+                       len(unsent_sla_misses))
+
+        # Trigger the SLA miss callbacks.
+        if unsent_sla_misses:
+            self.send_sla_notifications(unsent_sla_misses, session=session)
+
+    @provide_session
+    def record_sla_misses(self, session=None):
+        """
+        Create SLAMiss records for task instances associated with tasks in this
+        DAG. This involves walking forward to address potentially unscheduled
+        but expected executions, since new DAG runs may not get created if
+        there are concurrency restrictions on the scheduler. We still want to
+        receive SLA notifications in that scenario!
+        In the future, it would be preferable to have an SLA monitoring service
+        that runs independently from the scheduler, so that the service
+        responsible for scheduling work is not also responsible for determining
+        whether work is being scheduled.
+        """
+        self.log.debug("Checking for SLA misses for DAG %s", self.dag_id)
+
+        # Get all current DagRuns.
+        scheduled_dagruns = DagRun.find(
+            dag_id=self.dag_id,
+            # TODO related to AIRFLOW-2236: determine how SLA misses should
+            # work for backfills and externally triggered
+            # DAG runs. At minimum they could have duration SLA misses.
+            external_trigger=False,
+            no_backfills=True,
+            # We aren't passing in the "state" parameter because we care about
+            # checking for SLAs whether the DAG run has failed, succeeded, or
+            # is still running.
+            session=session
+        )
+
+        # TODO: Is there a better limit here than "look at most recent 100"?
+        # Perhaps there should be a configurable lookback window on the DAG,
+        # for how many runs to consider SLA violations for.
+        scheduled_dagruns = scheduled_dagruns[-100:]
+        scheduled_dagrun_ids = [d.id for d in scheduled_dagruns]
+
+        TI = TaskInstance
+        DR = DagRun
+
+        if scheduled_dagrun_ids:
+            # Find full, existing TIs for these DagRuns.
+            scheduled_tis = (
+                session.query(TI)
+                .outerjoin(DR, and_(
+                    DR.dag_id == TI.dag_id,
+                    DR.execution_date == TI.execution_date))
+                # Only look at TIs for this DAG.
+                .filter(TI.dag_id == self.dag_id)
+                # Only look at TIs that *still* exist in this DAG.
+                .filter(TI.task_id.in_(self.task_ids))
+                # Don't look for success/skip TIs. We check SLAs often, so
+                # there's little chance that a TI switches to successful
+                # after an SLA miss but before we notice; and this should
+                # be a major perf boost (since most TIs are successful or
+                # skipped).
+                .filter(or_(
+                    # has to be written this way to account for sql nulls
+                    TI.state == None, # noqa E711
+                    not_(TI.state.in_((State.SUCCESS, State.SKIPPED)))
+                ))
+                # Only look at specified DagRuns
+                .filter(DR.id.in_(scheduled_dagrun_ids))
+                # If the DAGRun is SUCCEEDED, then everything has gone
+                # according to plan. But if it's FAILED, someone may be
+                # coming to fix it, and SLAs for tasks in it will still
+                # matter.
+                .filter(DR.state != State.SUCCESS)
+                .order_by(asc(DR.execution_date))
+                .all()
+            )
+        else:
+            scheduled_tis = []
+
+        self.log.debug(
+            "Found {} outstanding TIs across {} dagruns for DAG {}".format(
+                len(scheduled_tis), len(scheduled_dagruns), self.dag_id))
+
+        # We need to examine unscheduled DAGRuns, too. If there are concurrency
+        # limitations, it's possible that a task instance will miss its SLA
+        # before its corresponding DAGRun even gets created.
+        last_dagrun = scheduled_dagruns[-1] if scheduled_dagruns else None
+
+        def unscheduled_tis(last_dagrun):
+            for dag_run in yield_unscheduled_runs(self, last_dagrun, ts):
+                for ti in yield_unscheduled_tis(dag_run, ts):
+                    yield ti
+
+        # Snapshot the time to check SLAs against.
+        ts = timezone.utcnow()
+
+        for ti in itertools.chain(scheduled_tis, unscheduled_tis(last_dagrun)):
+            ti.task = self.task_dict[ti.task_id]
+            # Ignore tasks that don't have SLAs, saving most calculation of
+            # future task instances.
+            if ti.task.has_slas():
+                create_sla_misses(ti, ts, session=session)
+
+        # Save any SlaMisses that were created in `create_sla_misses()`
+        session.commit()
+
+    @provide_session
+    def get_unsent_sla_notifications(self, session=None):
+        """
+        Find all SlaMisses for this DAG that haven't yet been notified.
+        """
+        return (
+            session
+            .query(SlaMiss)
+            .filter(SlaMiss.notification_sent == False)  # noqa
+            .filter(SlaMiss.dag_id == self.dag_id)
+            .all()
+        )
+
+    @provide_session
+    def send_sla_notifications(self, sla_misses, session=None):

Review comment:
       Sure.

##########
File path: airflow/models/dag.py
##########
@@ -1582,6 +1580,196 @@ def sync_to_db(self, sync_time=None, session=None):
         """
         self.bulk_sync_to_db([self], sync_time, session)
 
+    @provide_session
+    def manage_slas(self, session=None):
+        """
+        Helper function to encapsulate the sequence of SLA operations.
+        """
+        # Create SlaMiss objects for the various types of SLA misses.
+        self.record_sla_misses(session=session)
+
+        # Collect pending SLA miss callbacks, either created immediately above
+        # or previously failed.
+        unsent_sla_misses = self.get_unsent_sla_notifications(session=session)
+        self.log.debug("Found %s unsent SLA miss notifications",
+                       len(unsent_sla_misses))
+
+        # Trigger the SLA miss callbacks.
+        if unsent_sla_misses:
+            self.send_sla_notifications(unsent_sla_misses, session=session)
+
+    @provide_session
+    def record_sla_misses(self, session=None):
+        """
+        Create SLAMiss records for task instances associated with tasks in this
+        DAG. This involves walking forward to address potentially unscheduled
+        but expected executions, since new DAG runs may not get created if
+        there are concurrency restrictions on the scheduler. We still want to
+        receive SLA notifications in that scenario!
+        In the future, it would be preferable to have an SLA monitoring service
+        that runs independently from the scheduler, so that the service
+        responsible for scheduling work is not also responsible for determining
+        whether work is being scheduled.
+        """
+        self.log.debug("Checking for SLA misses for DAG %s", self.dag_id)
+
+        # Get all current DagRuns.
+        scheduled_dagruns = DagRun.find(
+            dag_id=self.dag_id,
+            # TODO related to AIRFLOW-2236: determine how SLA misses should
+            # work for backfills and externally triggered
+            # DAG runs. At minimum they could have duration SLA misses.
+            external_trigger=False,
+            no_backfills=True,
+            # We aren't passing in the "state" parameter because we care about
+            # checking for SLAs whether the DAG run has failed, succeeded, or
+            # is still running.
+            session=session
+        )
+
+        # TODO: Is there a better limit here than "look at most recent 100"?
+        # Perhaps there should be a configurable lookback window on the DAG,
+        # for how many runs to consider SLA violations for.
+        scheduled_dagruns = scheduled_dagruns[-100:]
+        scheduled_dagrun_ids = [d.id for d in scheduled_dagruns]
+
+        TI = TaskInstance
+        DR = DagRun
+
+        if scheduled_dagrun_ids:
+            # Find full, existing TIs for these DagRuns.

Review comment:
       I'm not sure what that means either; I'll remove it.

##########
File path: airflow/models/dag.py
##########
@@ -1582,6 +1580,196 @@ def sync_to_db(self, sync_time=None, session=None):
         """
         self.bulk_sync_to_db([self], sync_time, session)
 
+    @provide_session
+    def manage_slas(self, session=None):
+        """
+        Helper function to encapsulate the sequence of SLA operations.
+        """
+        # Create SlaMiss objects for the various types of SLA misses.
+        self.record_sla_misses(session=session)
+
+        # Collect pending SLA miss callbacks, either created immediately above
+        # or previously failed.
+        unsent_sla_misses = self.get_unsent_sla_notifications(session=session)
+        self.log.debug("Found %s unsent SLA miss notifications",
+                       len(unsent_sla_misses))
+
+        # Trigger the SLA miss callbacks.
+        if unsent_sla_misses:
+            self.send_sla_notifications(unsent_sla_misses, session=session)
+
+    @provide_session
+    def record_sla_misses(self, session=None):
+        """
+        Create SLAMiss records for task instances associated with tasks in this
+        DAG. This involves walking forward to address potentially unscheduled
+        but expected executions, since new DAG runs may not get created if
+        there are concurrency restrictions on the scheduler. We still want to
+        receive SLA notifications in that scenario!
+        In the future, it would be preferable to have an SLA monitoring service
+        that runs independently from the scheduler, so that the service
+        responsible for scheduling work is not also responsible for determining
+        whether work is being scheduled.
+        """
+        self.log.debug("Checking for SLA misses for DAG %s", self.dag_id)
+
+        # Get all current DagRuns.
+        scheduled_dagruns = DagRun.find(
+            dag_id=self.dag_id,
+            # TODO related to AIRFLOW-2236: determine how SLA misses should
+            # work for backfills and externally triggered
+            # DAG runs. At minimum they could have duration SLA misses.
+            external_trigger=False,
+            no_backfills=True,
+            # We aren't passing in the "state" parameter because we care about
+            # checking for SLAs whether the DAG run has failed, succeeded, or
+            # is still running.
+            session=session
+        )
+
+        # TODO: Is there a better limit here than "look at most recent 100"?
+        # Perhaps there should be a configurable lookback window on the DAG,
+        # for how many runs to consider SLA violations for.
+        scheduled_dagruns = scheduled_dagruns[-100:]
+        scheduled_dagrun_ids = [d.id for d in scheduled_dagruns]
+
+        TI = TaskInstance
+        DR = DagRun
+
+        if scheduled_dagrun_ids:
+            # Find full, existing TIs for these DagRuns.
+            scheduled_tis = (
+                session.query(TI)
+                .outerjoin(DR, and_(
+                    DR.dag_id == TI.dag_id,
+                    DR.execution_date == TI.execution_date))
+                # Only look at TIs for this DAG.
+                .filter(TI.dag_id == self.dag_id)
+                # Only look at TIs that *still* exist in this DAG.
+                .filter(TI.task_id.in_(self.task_ids))
+                # Don't look for success/skip TIs. We check SLAs often, so
+                # there's little chance that a TI switches to successful
+                # after an SLA miss but before we notice; and this should
+                # be a major perf boost (since most TIs are successful or
+                # skipped).
+                .filter(or_(
+                    # has to be written this way to account for sql nulls
+                    TI.state == None, # noqa E711
+                    not_(TI.state.in_((State.SUCCESS, State.SKIPPED)))
+                ))
+                # Only look at specified DagRuns
+                .filter(DR.id.in_(scheduled_dagrun_ids))
+                # If the DAGRun is SUCCEEDED, then everything has gone
+                # according to plan. But if it's FAILED, someone may be
+                # coming to fix it, and SLAs for tasks in it will still
+                # matter.
+                .filter(DR.state != State.SUCCESS)
+                .order_by(asc(DR.execution_date))
+                .all()
+            )
+        else:
+            scheduled_tis = []
+
+        self.log.debug(
+            "Found {} outstanding TIs across {} dagruns for DAG {}".format(
+                len(scheduled_tis), len(scheduled_dagruns), self.dag_id))
+
+        # We need to examine unscheduled DAGRuns, too. If there are concurrency
+        # limitations, it's possible that a task instance will miss its SLA
+        # before its corresponding DAGRun even gets created.
+        last_dagrun = scheduled_dagruns[-1] if scheduled_dagruns else None
+
+        def unscheduled_tis(last_dagrun):
+            for dag_run in yield_unscheduled_runs(self, last_dagrun, ts):

Review comment:
       `ts` is created at line 1688. The idea is to take a single snapshot of the
current time and use it for every comparison, rather than fetching a fresh
timestamp on the fly. I'm indifferent between these two approaches and happy to
change it if there's a good reason to prefer either one.

##########
File path: airflow/models/dag.py
##########
@@ -1582,6 +1580,196 @@ def sync_to_db(self, sync_time=None, session=None):
         """
         self.bulk_sync_to_db([self], sync_time, session)
 
+    @provide_session
+    def manage_slas(self, session=None):
+        """
+        Helper function to encapsulate the sequence of SLA operations.
+        """
+        # Create SlaMiss objects for the various types of SLA misses.
+        self.record_sla_misses(session=session)
+
+        # Collect pending SLA miss callbacks, either created immediately above
+        # or previously failed.
+        unsent_sla_misses = self.get_unsent_sla_notifications(session=session)
+        self.log.debug("Found %s unsent SLA miss notifications",
+                       len(unsent_sla_misses))
+
+        # Trigger the SLA miss callbacks.
+        if unsent_sla_misses:
+            self.send_sla_notifications(unsent_sla_misses, session=session)
+
+    @provide_session
+    def record_sla_misses(self, session=None):
+        """
+        Create SLAMiss records for task instances associated with tasks in this
+        DAG. This involves walking forward to address potentially unscheduled
+        but expected executions, since new DAG runs may not get created if
+        there are concurrency restrictions on the scheduler. We still want to
+        receive SLA notifications in that scenario!
+        In the future, it would be preferable to have an SLA monitoring service
+        that runs independently from the scheduler, so that the service
+        responsible for scheduling work is not also responsible for determining
+        whether work is being scheduled.
+        """
+        self.log.debug("Checking for SLA misses for DAG %s", self.dag_id)
+
+        # Get all current DagRuns.
+        scheduled_dagruns = DagRun.find(
+            dag_id=self.dag_id,
+            # TODO related to AIRFLOW-2236: determine how SLA misses should
+            # work for backfills and externally triggered
+            # DAG runs. At minimum they could have duration SLA misses.
+            external_trigger=False,
+            no_backfills=True,
+            # We aren't passing in the "state" parameter because we care about
+            # checking for SLAs whether the DAG run has failed, succeeded, or
+            # is still running.
+            session=session
+        )
+
+        # TODO: Is there a better limit here than "look at most recent 100"?
+        # Perhaps there should be a configurable lookback window on the DAG,
+        # for how many runs to consider SLA violations for.
+        scheduled_dagruns = scheduled_dagruns[-100:]
+        scheduled_dagrun_ids = [d.id for d in scheduled_dagruns]
+
+        TI = TaskInstance
+        DR = DagRun
+
+        if scheduled_dagrun_ids:
+            # Find full, existing TIs for these DagRuns.
+            scheduled_tis = (
+                session.query(TI)
+                .outerjoin(DR, and_(
+                    DR.dag_id == TI.dag_id,
+                    DR.execution_date == TI.execution_date))
+                # Only look at TIs for this DAG.
+                .filter(TI.dag_id == self.dag_id)
+                # Only look at TIs that *still* exist in this DAG.
+                .filter(TI.task_id.in_(self.task_ids))
+                # Don't look for success/skip TIs. We check SLAs often, so
+                # there's little chance that a TI switches to successful
+                # after an SLA miss but before we notice; and this should
+                # be a major perf boost (since most TIs are successful or
+                # skipped).
+                .filter(or_(
+                    # has to be written this way to account for sql nulls
+                    TI.state == None, # noqa E711
+                    not_(TI.state.in_((State.SUCCESS, State.SKIPPED)))
+                ))
+                # Only look at specified DagRuns
+                .filter(DR.id.in_(scheduled_dagrun_ids))
+                # If the DAGRun is SUCCEEDED, then everything has gone
+                # according to plan. But if it's FAILED, someone may be
+                # coming to fix it, and SLAs for tasks in it will still
+                # matter.
+                .filter(DR.state != State.SUCCESS)
+                .order_by(asc(DR.execution_date))
+                .all()
+            )
+        else:
+            scheduled_tis = []
+
+        self.log.debug(
+            "Found {} outstanding TIs across {} dagruns for DAG {}".format(
+                len(scheduled_tis), len(scheduled_dagruns), self.dag_id))
+
+        # We need to examine unscheduled DAGRuns, too. If there are concurrency
+        # limitations, it's possible that a task instance will miss its SLA
+        # before its corresponding DAGRun even gets created.
+        last_dagrun = scheduled_dagruns[-1] if scheduled_dagruns else None
+
+        def unscheduled_tis(last_dagrun):
+            for dag_run in yield_unscheduled_runs(self, last_dagrun, ts):
+                for ti in yield_unscheduled_tis(dag_run, ts):
+                    yield ti
+
+        # Snapshot the time to check SLAs against.
+        ts = timezone.utcnow()
+
+        for ti in itertools.chain(scheduled_tis, unscheduled_tis(last_dagrun)):
+            ti.task = self.task_dict[ti.task_id]
+            # Ignore tasks that don't have SLAs, saving most calculation of
+            # future task instances.
+            if ti.task.has_slas():
+                create_sla_misses(ti, ts, session=session)
+
+        # Save any SlaMisses that were created in `create_sla_misses()`
+        session.commit()
+
+    @provide_session
+    def get_unsent_sla_notifications(self, session=None):
+        """
+        Find all SlaMisses for this DAG that haven't yet been notified.
+        """
+        return (
+            session
+            .query(SlaMiss)
+            .filter(SlaMiss.notification_sent == False)  # noqa
+            .filter(SlaMiss.dag_id == self.dag_id)
+            .all()
+        )
+
+    @provide_session
+    def send_sla_notifications(self, sla_misses, session=None):
+        """
+        Given a list of SLA misses, send emails and/or do SLA miss callback.
+        """
+        if not sla_misses:
+            self.log.warning("send_sla_notifications was called without any "
+                             "SLA notifications to send!")
+            return
+
+        # Retrieve the context for this TI, but patch in the SLA miss object.
+        for sla_miss in sla_misses:
+            if sla_miss.notification_sent:
+                self.log.debug("SLA miss %s has already had a notification 
sent, "
+                               "ignoring.", sla_miss)
+
+            TI = TaskInstance
+            ti = session.query(TI).filter(
+                TI.dag_id == sla_miss.dag_id,
+                TI.task_id == sla_miss.task_id,
+                TI.execution_date == sla_miss.execution_date,
+            ).all()
+
+            # Use the TI if found
+            task = self.get_task(sla_miss.task_id)
+            if ti:
+                ti = ti.pop()
+                ti.task = task
+            # Else make a temporary one.
+            else:
+                ti = TaskInstance(task, sla_miss.execution_date)
+                ti.task = task

Review comment:
       I believe this is for cases where an SLAMiss is triggered for task
instances that have yet to be scheduled. For example, if a TI has never been
scheduled due to concurrency limits and its task has `expected_start` set, an
SLAMiss will be created, but there won't be a corresponding TI in the database.

##########
File path: airflow/models/dag.py
##########
@@ -1582,6 +1580,196 @@ def sync_to_db(self, sync_time=None, session=None):
         """
         self.bulk_sync_to_db([self], sync_time, session)
 
+    @provide_session
+    def manage_slas(self, session=None):
+        """
+        Helper function to encapsulate the sequence of SLA operations.
+        """
+        # Create SlaMiss objects for the various types of SLA misses.
+        self.record_sla_misses(session=session)
+
+        # Collect pending SLA miss callbacks, either created immediately above
+        # or previously failed.
+        unsent_sla_misses = self.get_unsent_sla_notifications(session=session)
+        self.log.debug("Found %s unsent SLA miss notifications",
+                       len(unsent_sla_misses))
+
+        # Trigger the SLA miss callbacks.
+        if unsent_sla_misses:
+            self.send_sla_notifications(unsent_sla_misses, session=session)
+
+    @provide_session
+    def record_sla_misses(self, session=None):
+        """
+        Create SLAMiss records for task instances associated with tasks in this
+        DAG. This involves walking forward to address potentially unscheduled
+        but expected executions, since new DAG runs may not get created if
+        there are concurrency restrictions on the scheduler. We still want to
+        receive SLA notifications in that scenario!
+        In the future, it would be preferable to have an SLA monitoring service
+        that runs independently from the scheduler, so that the service
+        responsible for scheduling work is not also responsible for determining
+        whether work is being scheduled.
+        """
+        self.log.debug("Checking for SLA misses for DAG %s", self.dag_id)
+
+        # Get all current DagRuns.
+        scheduled_dagruns = DagRun.find(
+            dag_id=self.dag_id,
+            # TODO related to AIRFLOW-2236: determine how SLA misses should
+            # work for backfills and externally triggered
+            # DAG runs. At minimum they could have duration SLA misses.
+            external_trigger=False,
+            no_backfills=True,
+            # We aren't passing in the "state" parameter because we care about
+            # checking for SLAs whether the DAG run has failed, succeeded, or
+            # is still running.
+            session=session
+        )
+
+        # TODO: Is there a better limit here than "look at most recent 100"?
+        # Perhaps there should be a configurable lookback window on the DAG,
+        # for how many runs to consider SLA violations for.
+        scheduled_dagruns = scheduled_dagruns[-100:]
+        scheduled_dagrun_ids = [d.id for d in scheduled_dagruns]
+
+        TI = TaskInstance
+        DR = DagRun
+
+        if scheduled_dagrun_ids:
+            # Find full, existing TIs for these DagRuns.
+            scheduled_tis = (
+                session.query(TI)
+                .outerjoin(DR, and_(
+                    DR.dag_id == TI.dag_id,
+                    DR.execution_date == TI.execution_date))
+                # Only look at TIs for this DAG.
+                .filter(TI.dag_id == self.dag_id)
+                # Only look at TIs that *still* exist in this DAG.
+                .filter(TI.task_id.in_(self.task_ids))
+                # Don't look for success/skip TIs. We check SLAs often, so
+                # there's little chance that a TI switches to successful
+                # after an SLA miss but before we notice; and this should
+                # be a major perf boost (since most TIs are successful or
+                # skipped).
+                .filter(or_(
+                    # has to be written this way to account for sql nulls
+                    TI.state == None, # noqa E711
+                    not_(TI.state.in_((State.SUCCESS, State.SKIPPED)))
+                ))
+                # Only look at specified DagRuns
+                .filter(DR.id.in_(scheduled_dagrun_ids))
+                # If the DAGRun is SUCCEEDED, then everything has gone
+                # according to plan. But if it's FAILED, someone may be
+                # coming to fix it, and SLAs for tasks in it will still
+                # matter.
+                .filter(DR.state != State.SUCCESS)
+                .order_by(asc(DR.execution_date))
+                .all()
+            )
+        else:
+            scheduled_tis = []
+
+        self.log.debug(
+            "Found {} outstanding TIs across {} dagruns for DAG {}".format(
+                len(scheduled_tis), len(scheduled_dagruns), self.dag_id))
+
+        # We need to examine unscheduled DAGRuns, too. If there are concurrency
+        # limitations, it's possible that a task instance will miss its SLA
+        # before its corresponding DAGRun even gets created.
+        last_dagrun = scheduled_dagruns[-1] if scheduled_dagruns else None
+
+        def unscheduled_tis(last_dagrun):
+            for dag_run in yield_unscheduled_runs(self, last_dagrun, ts):
+                for ti in yield_unscheduled_tis(dag_run, ts):
+                    yield ti
+
+        # Snapshot the time to check SLAs against.
+        ts = timezone.utcnow()
+
+        for ti in itertools.chain(scheduled_tis, unscheduled_tis(last_dagrun)):
+            ti.task = self.task_dict[ti.task_id]
+            # Ignore tasks that don't have SLAs, saving most calculation of
+            # future task instances.
+            if ti.task.has_slas():
+                create_sla_misses(ti, ts, session=session)
+
+        # Save any SlaMisses that were created in `create_sla_misses()`
+        session.commit()
+
+    @provide_session
+    def get_unsent_sla_notifications(self, session=None):
+        """
+        Find all SlaMisses for this DAG that haven't yet been notified.
+        """
+        return (
+            session
+            .query(SlaMiss)
+            .filter(SlaMiss.notification_sent == False)  # noqa
+            .filter(SlaMiss.dag_id == self.dag_id)
+            .all()
+        )
+
+    @provide_session
+    def send_sla_notifications(self, sla_misses, session=None):
+        """
+        Given a list of SLA misses, send emails and/or do SLA miss callback.
+        """
+        if not sla_misses:
+            self.log.warning("send_sla_notifications was called without any "
+                             "SLA notifications to send!")
+            return
+
+        # Retrieve the context for this TI, but patch in the SLA miss object.
+        for sla_miss in sla_misses:
+            if sla_miss.notification_sent:
+                self.log.debug("SLA miss %s has already had a notification 
sent, "
+                               "ignoring.", sla_miss)
+
+            TI = TaskInstance
+            ti = session.query(TI).filter(
+                TI.dag_id == sla_miss.dag_id,
+                TI.task_id == sla_miss.task_id,
+                TI.execution_date == sla_miss.execution_date,
+            ).all()
+
+            # Use the TI if found
+            task = self.get_task(sla_miss.task_id)
+            if ti:
+                ti = ti.pop()
+                ti.task = task
+            # Else make a temporary one.
+            else:
+                ti = TaskInstance(task, sla_miss.execution_date)
+                ti.task = task
+
+            notification_sent = False
+            # If no callback exists, we don't want to send any notification;
+            # but we do want to update the SlaMiss in the database so that it
+            # doesn't keep looping.
+            if not task.sla_miss_callback:
+                notification_sent = True
+            else:
+                self.log.info("Calling sla_miss_callback for %s", sla_miss)
+                try:
+                    # Patch context with the current SLA miss.
+                    context = ti.get_template_context()
+                    context["sla_miss"] = sla_miss
+                    task.sla_miss_callback(context)
+                    notification_sent = True

Review comment:
       I'm not sure either; setting it directly makes more sense to me.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to