This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 631ac48  Some Pylint fixes in airflow/models/taskinstance.py (#9674)
631ac48 is described below

commit 631ac484f14e7f3f9637e1229252769d61e388e1
Author: Kaxil Naik <[email protected]>
AuthorDate: Mon Jul 6 20:32:02 2020 +0100

    Some Pylint fixes in airflow/models/taskinstance.py (#9674)
---
 airflow/executors/base_executor.py          |   3 +-
 airflow/jobs/base_job.py                    |  25 ++-
 airflow/models/taskinstance.py              | 308 ++++++++++++++++++----------
 airflow/operators/branch_operator.py        |   3 +-
 airflow/operators/python.py                 |   3 +-
 airflow/ti_deps/deps/dagrun_exists_dep.py   |   2 +-
 airflow/ti_deps/deps/ready_to_reschedule.py |   2 +-
 tests/models/test_skipmixin.py              |   4 +-
 8 files changed, 218 insertions(+), 132 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 2471c3f..935dd4a 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -21,8 +21,7 @@ from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 from airflow.configuration import conf
-from airflow.models import TaskInstance
-from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKeyType
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKeyType
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.state import State
diff --git a/airflow/jobs/base_job.py b/airflow/jobs/base_job.py
index 6fe8bdb..e52a766 100644
--- a/airflow/jobs/base_job.py
+++ b/airflow/jobs/base_job.py
@@ -25,11 +25,12 @@ from sqlalchemy import Column, Index, Integer, String, and_
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm.session import make_transient
 
-from airflow import models
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.models.base import ID_LEN, Base
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
 from airflow.stats import Stats
 from airflow.utils import helpers, timezone
 from airflow.utils.helpers import convert_camel_to_snake
@@ -268,22 +269,20 @@ class BaseJob(Base, LoggingMixin):
         running_tis = self.executor.running
 
         resettable_states = [State.SCHEDULED, State.QUEUED]
-        TI = models.TaskInstance
-        DR = models.DagRun
         if filter_by_dag_run is None:
             resettable_tis = (
                 session
-                .query(TI)
+                .query(TaskInstance)
                 .join(
-                    DR,
+                    DagRun,
                     and_(
-                        TI.dag_id == DR.dag_id,
-                        TI.execution_date == DR.execution_date))
+                        TaskInstance.dag_id == DagRun.dag_id,
+                        TaskInstance.execution_date == DagRun.execution_date))
                 .filter(
                     # pylint: disable=comparison-with-callable
-                    DR.state == State.RUNNING,
-                    DR.run_type != DagRunType.BACKFILL_JOB.value,
-                    TI.state.in_(resettable_states))).all()
+                    DagRun.state == State.RUNNING,
+                    DagRun.run_type != DagRunType.BACKFILL_JOB.value,
+                    TaskInstance.state.in_(resettable_states))).all()
         else:
             resettable_tis = filter_by_dag_run.get_task_instances(state=resettable_states,
                                                                   session=session)
@@ -300,9 +299,9 @@ class BaseJob(Base, LoggingMixin):
             if not items:
                 return result
 
-            filter_for_tis = TI.filter_for_tis(items)
-            reset_tis = session.query(TI).filter(
-                filter_for_tis, TI.state.in_(resettable_states)
+            filter_for_tis = TaskInstance.filter_for_tis(items)
+            reset_tis = session.query(TaskInstance).filter(
+                filter_for_tis, TaskInstance.state.in_(resettable_states)
             ).with_for_update().all()
 
             for ti in reset_tis:
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 4d2a0d7..8360097 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -131,7 +131,7 @@ def clear_task_instances(tis,
 TaskInstanceKeyType = Tuple[str, str, datetime, int]
 
 
-class TaskInstance(Base, LoggingMixin):
+class TaskInstance(Base, LoggingMixin):     # pylint: disable=R0902,R0904
     """
     Task instances store the state of a task instance. This table is the
     authority and single source of truth around what tasks have run and the
@@ -180,6 +180,7 @@ class TaskInstance(Base, LoggingMixin):
     )
 
     def __init__(self, task, execution_date: datetime, state: Optional[str] = None):
+        super().__init__()
         self.dag_id = task.dag_id
         self.task_id = task.task_id
         self.task = task
@@ -209,6 +210,8 @@ class TaskInstance(Base, LoggingMixin):
         # Is this TaskInstance being currently running within `airflow tasks run --raw`.
         # Not persisted to the database so only valid for the current process
         self.raw = False
+        # can be changed when calling 'run'
+        self.test_mode = False
 
     @reconstructor
     def init_on_load(self):
@@ -249,9 +252,10 @@ class TaskInstance(Base, LoggingMixin):
 
     @property
     def next_try_number(self):
+        """Setting Next Try Number"""
         return self._try_number + 1
 
-    def command_as_list(
+    def command_as_list(    # pylint: disable=too-many-arguments
             self,
             mark_success=False,
             ignore_all_deps=False,
@@ -297,7 +301,7 @@ class TaskInstance(Base, LoggingMixin):
             cfg_path=cfg_path)
 
     @staticmethod
-    def generate_command(dag_id: str,
+    def generate_command(dag_id: str,     # pylint: disable=too-many-arguments
                          task_id: str,
                          execution_date: datetime,
                          mark_success: Optional[bool] = False,
@@ -383,6 +387,7 @@ class TaskInstance(Base, LoggingMixin):
 
     @property
     def log_filepath(self):
+        """Filepath for TaskInstance"""
         iso = self.execution_date.isoformat()
         log = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER'))
         return ("{log}/{dag_id}/{task_id}/{iso}.log".format(
@@ -390,6 +395,7 @@ class TaskInstance(Base, LoggingMixin):
 
     @property
     def log_url(self):
+        """Log URL for TaskInstance"""
         iso = quote(self.execution_date.isoformat())
         base_url = conf.get('webserver', 'BASE_URL')
         return base_url + (
@@ -401,6 +407,7 @@ class TaskInstance(Base, LoggingMixin):
 
     @property
     def mark_success_url(self):
+        """URL to mark TI success"""
         iso = quote(self.execution_date.isoformat())
         base_url = conf.get('webserver', 'BASE_URL')
         return base_url + (
@@ -418,6 +425,9 @@ class TaskInstance(Base, LoggingMixin):
         Get the very latest state from the database, if a session is passed,
         we use and looking up the state becomes part of the session, otherwise
         a new session is used.
+
+        :param session: SQLAlchemy ORM Session
+        :type session: Session
         """
         ti = session.query(TaskInstance).filter(
             TaskInstance.dag_id == self.dag_id,
@@ -434,6 +444,9 @@ class TaskInstance(Base, LoggingMixin):
     def error(self, session=None):
         """
         Forces the task instance's state to FAILED in the database.
+
+        :param session: SQLAlchemy ORM Session
+        :type session: Session
         """
         self.log.error("Recording the task instance as FAILED")
         self.state = State.FAILED
@@ -445,10 +458,14 @@ class TaskInstance(Base, LoggingMixin):
         """
         Refreshes the task instance from the database based on the primary key
 
+        :param session: SQLAlchemy ORM Session
+        :type session: Session
         :param lock_for_update: if True, indicates that the database should
             lock the TaskInstance (issuing a FOR UPDATE clause) until the
             session is committed.
+        :type lock_for_update: bool
         """
+        self.log.debug("Refreshing TaskInstance %s from DB", self)
 
         qry = session.query(TaskInstance).filter(
             TaskInstance.dag_id == self.dag_id,
@@ -467,7 +484,7 @@ class TaskInstance(Base, LoggingMixin):
             self.state = ti.state
             # Get the raw value of try_number column, don't read through the
             # accessor here otherwise it will be incremented by one already.
-            self.try_number = ti._try_number
+            self.try_number = ti._try_number    # pylint: disable=protected-access
             self.max_tries = ti.max_tries
             self.hostname = ti.hostname
             self.unixname = ti.unixname
@@ -482,6 +499,8 @@ class TaskInstance(Base, LoggingMixin):
         else:
             self.state = None
 
+        self.log.debug("Refreshed TaskInstance %s", self)
+
     def refresh_from_task(self, task, pool_override=None):
         """
         Copy common attributes from the given task.
@@ -504,13 +523,18 @@ class TaskInstance(Base, LoggingMixin):
     def clear_xcom_data(self, session=None):
         """
         Clears all XCom data from the database for the task instance
+
+        :param session: SQLAlchemy ORM Session
+        :type session: Session
         """
+        self.log.debug("Clearing XCom data")
         session.query(XCom).filter(
             XCom.dag_id == self.dag_id,
             XCom.task_id == self.task_id,
             XCom.execution_date == self.execution_date
         ).delete()
         session.commit()
+        self.log.debug("XCom data cleared")
 
     @property
     def key(self) -> TaskInstanceKeyType:
@@ -521,6 +545,17 @@ class TaskInstance(Base, LoggingMixin):
 
     @provide_session
     def set_state(self, state, session=None, commit=True):
+        """
+        Set TaskInstance state
+
+        :param state: State to set for the TI
+        :type state: str
+        :param session: SQLAlchemy ORM Session
+        :type session: Session
+        :param commit: Whether or not to commit session
+        :type commit: bool
+        """
+        self.log.debug("Setting task state for %s to %s", self, state)
         self.state = state
         self.start_date = timezone.utcnow()
         self.end_date = timezone.utcnow()
@@ -546,6 +581,9 @@ class TaskInstance(Base, LoggingMixin):
         This is useful when you do not want to start processing the next
         schedule of a task until the dependents are done. For instance,
         if the task DROPs and recreates a table.
+
+        :param session: SQLAlchemy ORM Session
+        :type session: Session
         """
         task = self.task
 
@@ -571,6 +609,7 @@ class TaskInstance(Base, LoggingMixin):
         The task instance for the task that ran before this task instance.
 
         :param state: If passed, it only take into account instances of a specific state.
+        :param session: SQLAlchemy ORM Session
         """
         dag = self.task.dag
         if dag:
@@ -643,6 +682,7 @@ class TaskInstance(Base, LoggingMixin):
         The execution date from property previous_ti_success.
 
         :param state: If passed, it only take into account instances of a specific state.
+        :param session: SQLAlchemy ORM Session
         """
         self.log.debug("previous_execution_date was called")
         prev_ti = self.get_previous_ti(state=state, session=session)
@@ -658,6 +698,7 @@ class TaskInstance(Base, LoggingMixin):
         The start date from property previous_ti_success.
 
         :param state: If passed, it only take into account instances of a specific state.
+        :param session: SQLAlchemy ORM Session
         """
         self.log.debug("previous_start_date was called")
         prev_ti = self.get_previous_ti(state=state, session=session)
@@ -723,6 +764,7 @@ class TaskInstance(Base, LoggingMixin):
             self,
             dep_context=None,
             session=None):
+        """Get failed Dependencies"""
         dep_context = dep_context or DepContext()
         for dep in dep_context.deps | self.task.deps:
             for dep_status in dep.get_dep_statuses(
@@ -756,13 +798,13 @@ class TaskInstance(Base, LoggingMixin):
             # will occurr in the modded_hash calculation.
             min_backoff = int(math.ceil(delay.total_seconds() * (2 ** (self.try_number - 2))))
             # deterministic per task instance
-            hash = int(hashlib.sha1("{}#{}#{}#{}".format(self.dag_id,
-                                                         self.task_id,
-                                                         self.execution_date,
-                                                         self.try_number)
-                                    .encode('utf-8')).hexdigest(), 16)
+            ti_hash = int(hashlib.sha1("{}#{}#{}#{}".format(self.dag_id,
+                                                            self.task_id,
+                                                            self.execution_date,
+                                                            self.try_number)
+                                       .encode('utf-8')).hexdigest(), 16)
             # between 1 and 1.0 * delay * (2^retry_number)
-            modded_hash = min_backoff + hash % min_backoff
+            modded_hash = min_backoff + ti_hash % min_backoff
             # timedelta has a maximum representable value. The exponentiation
             # here means this value can be exceeded after a certain number
             # of tries (around 50 if the initial delay is 1s, even fewer if
@@ -786,11 +828,11 @@ class TaskInstance(Base, LoggingMixin):
                 self.next_retry_datetime() < timezone.utcnow())
 
     @provide_session
-    def get_dagrun(self, session=None):
+    def get_dagrun(self, session: Session = None):
         """
         Returns the DagRun for this TaskInstance
 
-        :param session:
+        :param session: SQLAlchemy ORM Session
         :return: DagRun
         """
         from airflow.models.dagrun import DagRun  # Avoid circular import
@@ -802,7 +844,7 @@ class TaskInstance(Base, LoggingMixin):
         return dr
 
     @provide_session
-    def check_and_change_state_before_execution(
+    def check_and_change_state_before_execution(    # pylint: disable=too-many-arguments
             self,
             verbose: bool = True,
             ignore_all_deps: bool = False,
@@ -833,11 +875,16 @@ class TaskInstance(Base, LoggingMixin):
         :type mark_success: bool
         :param test_mode: Doesn't record success or failure in the DB
         :type test_mode: bool
+        :param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID
+        :type job_id: str
         :param pool: specifies the pool to use to run the task instance
         :type pool: str
+        :param session: SQLAlchemy ORM Session
+        :type session: Session
         :return: whether the state was changed to running or not
         :rtype: bool
         """
+
         task = self.task
         self.refresh_from_task(task, pool_override=pool)
         self.test_mode = test_mode
@@ -849,7 +896,7 @@ class TaskInstance(Base, LoggingMixin):
             Stats.incr('previously_succeeded', 1, 1)
 
         # TODO: Logging needs cleanup, not clear what is being printed
-        hr = "\n" + ("-" * 80)  # Line break
+        hr_line_break = "\n" + ("-" * 80)  # Line break
 
         if not mark_success:
             # Firstly find non-runnable and non-requeueable tis.
@@ -892,22 +939,22 @@ class TaskInstance(Base, LoggingMixin):
                     session=session,
                     verbose=True):
                 self.state = State.NONE
-                self.log.warning(hr)
+                self.log.warning(hr_line_break)
                 self.log.warning(
                     "Rescheduling due to concurrency limits reached "
                     "at task runtime. Attempt %s of "
                     "%s. State set to NONE.", self.try_number, self.max_tries + 1
                 )
-                self.log.warning(hr)
+                self.log.warning(hr_line_break)
                 self.queued_dttm = timezone.utcnow()
                 session.merge(self)
                 session.commit()
                 return False
 
         # print status message
-        self.log.info(hr)
+        self.log.info(hr_line_break)
         self.log.info("Starting attempt %s of %s", self.try_number, self.max_tries + 1)
-        self.log.info(hr)
+        self.log.info(hr_line_break)
         self._try_number += 1
 
         if not test_mode:
@@ -957,9 +1004,9 @@ class TaskInstance(Base, LoggingMixin):
         :type test_mode: bool
         :param pool: specifies the pool to use to run the task instance
         :type pool: str
+        :param session: SQLAlchemy ORM Session
+        :type session: Session
         """
-        from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
-        from airflow.sensors.base_sensor_operator import BaseSensorOperator
 
         task = self.task
         self.test_mode = test_mode
@@ -974,80 +1021,7 @@ class TaskInstance(Base, LoggingMixin):
         try:
             if not mark_success:
                 context = self.get_template_context()
-
-                task_copy = task.prepare_for_execution()
-
-                # Sensors in `poke` mode can block execution of DAGs when running
-                # with single process executor, thus we change the mode to`reschedule`
-                # to allow parallel task being scheduled and executed
-                if isinstance(task_copy, BaseSensorOperator) and \
-                        conf.get('core', 'executor') == "DebugExecutor":
-                    self.log.warning("DebugExecutor changes sensor mode to 'reschedule'.")
-                    task_copy.mode = 'reschedule'
-
-                self.task = task_copy
-
-                def signal_handler(signum, frame):
-                    self.log.error("Received SIGTERM. Terminating subprocesses.")
-                    task_copy.on_kill()
-                    raise AirflowException("Task received SIGTERM signal")
-                signal.signal(signal.SIGTERM, signal_handler)
-
-                # Don't clear Xcom until the task is certain to execute
-                self.clear_xcom_data()
-
-                start_time = time.time()
-
-                self.render_templates(context=context)
-                if STORE_SERIALIZED_DAGS:
-                    RTIF.write(RTIF(ti=self, render_templates=False), session=session)
-                    RTIF.delete_old_records(self.task_id, self.dag_id, session=session)
-
-                # Export context to make it available for operators to use.
-                airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
-                self.log.info("Exporting the following env vars:\n%s",
-                              '\n'.join(["{}={}".format(k, v)
-                                         for k, v in airflow_context_vars.items()]))
-                os.environ.update(airflow_context_vars)
-                task_copy.pre_execute(context=context)
-
-                try:
-                    if task.on_execute_callback:
-                        task.on_execute_callback(context)
-                except Exception as e3:
-                    self.log.error("Failed when executing execute callback")
-                    self.log.exception(e3)
-
-                # If a timeout is specified for the task, make it fail
-                # if it goes beyond
-                if task_copy.execution_timeout:
-                    try:
-                        with timeout(int(
-                                task_copy.execution_timeout.total_seconds())):
-                            result = task_copy.execute(context=context)
-                    except AirflowTaskTimeout:
-                        task_copy.on_kill()
-                        raise
-                else:
-                    result = task_copy.execute(context=context)
-
-                # If the task returns a result, push an XCom containing it
-                if task_copy.do_xcom_push and result is not None:
-                    self.xcom_push(key=XCOM_RETURN_KEY, value=result)
-
-                task_copy.post_execute(context=context, result=result)
-
-                end_time = time.time()
-                duration = end_time - start_time
-                Stats.timing(
-                    'dag.{dag_id}.{task_id}.duration'.format(
-                        dag_id=task_copy.dag_id,
-                        task_id=task_copy.task_id),
-                    duration)
-
-                Stats.incr('operator_successes_{}'.format(
-                    self.task.__class__.__name__), 1, 1)
-                Stats.incr('ti_successes')
+                self._prepare_and_execute_task_with_callbacks(context, session, task)
             self.refresh_from_db(lock_for_update=True)
             self.state = State.SUCCESS
         except AirflowSkipException as e:
@@ -1089,13 +1063,7 @@ class TaskInstance(Base, LoggingMixin):
         finally:
             Stats.incr('ti.finish.{}.{}.{}'.format(task.dag_id, task.task_id, self.state))
 
-        # Success callback
-        try:
-            if task.on_success_callback:
-                task.on_success_callback(context)
-        except Exception as e3:
-            self.log.error("Failed when executing success callback")
-            self.log.exception(e3)
+        self._run_success_callback(context, task)
 
         # Recording SUCCESS
         self.end_date = timezone.utcnow()
@@ -1114,8 +1082,108 @@ class TaskInstance(Base, LoggingMixin):
             session.merge(self)
         session.commit()
 
+    def _prepare_and_execute_task_with_callbacks(self, context, session, task):
+        """
+        Prepare Task for Execution
+        """
+        from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
+        from airflow.sensors.base_sensor_operator import BaseSensorOperator
+
+        task_copy = task.prepare_for_execution()
+        # Sensors in `poke` mode can block execution of DAGs when running
+        # with single process executor, thus we change the mode to`reschedule`
+        # to allow parallel task being scheduled and executed
+        if (
+            isinstance(task_copy, BaseSensorOperator) and
+            conf.get('core', 'executor') == "DebugExecutor"
+        ):
+            self.log.warning("DebugExecutor changes sensor mode to 'reschedule'.")
+            task_copy.mode = 'reschedule'
+        self.task = task_copy
+
+        def signal_handler(signum, frame):  # pylint: disable=unused-argument
+            self.log.error("Received SIGTERM. Terminating subprocesses.")
+            task_copy.on_kill()
+            raise AirflowException("Task received SIGTERM signal")
+
+        signal.signal(signal.SIGTERM, signal_handler)
+
+        # Don't clear Xcom until the task is certain to execute
+        self.clear_xcom_data()
+        start_time = time.time()
+
+        self.render_templates(context=context)
+        if STORE_SERIALIZED_DAGS:
+            RTIF.write(RTIF(ti=self, render_templates=False), session=session)
+            RTIF.delete_old_records(self.task_id, self.dag_id, session=session)
+
+        # Export context to make it available for operators to use.
+        airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
+        self.log.info("Exporting the following env vars:\n%s",
+                      '\n'.join(["{}={}".format(k, v)
+                                 for k, v in airflow_context_vars.items()]))
+
+        os.environ.update(airflow_context_vars)
+
+        # Run pre_execute callback
+        task_copy.pre_execute(context=context)
+
+        # Run on_execute callback
+        self._run_execute_callback(context, task)
+
+        # Execute the task
+        result = self._execute_task(context, task_copy)
+
+        # Run post_execute callback
+        task_copy.post_execute(context=context, result=result)
+
+        end_time = time.time()
+        duration = end_time - start_time
+        Stats.timing('dag.{dag_id}.{task_id}.duration'.format(dag_id=task_copy.dag_id,
+                                                              task_id=task_copy.task_id),
+                     duration)
+        Stats.incr('operator_successes_{}'.format(self.task.__class__.__name__), 1, 1)
+        Stats.incr('ti_successes')
+
+    def _run_success_callback(self, context, task):
+        """Functions that need to be run if Task is successful"""
+        # Success callback
+        try:
+            if task.on_success_callback:
+                task.on_success_callback(context)
+        except Exception as exc:  # pylint: disable=broad-except
+            self.log.error("Failed when executing success callback")
+            self.log.exception(exc)
+
+    def _execute_task(self, context, task_copy):
+        """Executes Task (optionally with a Timeout) and pushes Xcom results"""
+        # If a timeout is specified for the task, make it fail
+        # if it goes beyond
+        if task_copy.execution_timeout:
+            try:
+                with timeout(int(task_copy.execution_timeout.total_seconds())):
+                    result = task_copy.execute(context=context)
+            except AirflowTaskTimeout:
+                task_copy.on_kill()
+                raise
+        else:
+            result = task_copy.execute(context=context)
+        # If the task returns a result, push an XCom containing it
+        if task_copy.do_xcom_push and result is not None:
+            self.xcom_push(key=XCOM_RETURN_KEY, value=result)
+        return result
+
+    def _run_execute_callback(self, context, task):
+        """Functions that need to be run before a Task is executed"""
+        try:
+            if task.on_execute_callback:
+                task.on_execute_callback(context)
+        except Exception as exc:  # pylint: disable=broad-except
+            self.log.error("Failed when executing execute callback")
+            self.log.exception(exc)
+
     @provide_session
-    def run(
+    def run(  # pylint: disable=too-many-arguments
             self,
             verbose: bool = True,
             ignore_all_deps: bool = False,
@@ -1127,6 +1195,7 @@ class TaskInstance(Base, LoggingMixin):
             job_id: Optional[str] = None,
             pool: Optional[str] = None,
             session=None) -> None:
+        """Run TaskInstance"""
         res = self.check_and_change_state_before_execution(
             verbose=verbose,
             ignore_all_deps=ignore_all_deps,
@@ -1147,6 +1216,7 @@ class TaskInstance(Base, LoggingMixin):
                 session=session)
 
     def dry_run(self):
+        """Only Renders Templates for the TI"""
         task = self.task
         task_copy = task.prepare_for_execution()
         self.task = task_copy
@@ -1155,7 +1225,11 @@ class TaskInstance(Base, LoggingMixin):
         task_copy.dry_run()
 
     @provide_session
-    def _handle_reschedule(self, actual_start_date, reschedule_exception, test_mode=False, context=None,
+    def _handle_reschedule(self,
+                           actual_start_date,
+                           reschedule_exception,
+                           test_mode=False,
+                           context=None,    # pylint: disable=unused-argument
                            session=None):
         # Don't record reschedule request in test mode
         if test_mode:
@@ -1182,6 +1256,7 @@ class TaskInstance(Base, LoggingMixin):
 
     @provide_session
     def handle_failure(self, error, test_mode=None, context=None, force_fail=False, session=None):
+        """Handle Failure for the TaskInstance"""
         if test_mode is None:
             test_mode = self.test_mode
         if context is None:
@@ -1236,17 +1311,17 @@ class TaskInstance(Base, LoggingMixin):
         if email_for_state and task.email:
             try:
                 self.email_alert(error)
-            except Exception as e2:
+            except Exception as exec2:     # pylint: disable=broad-except
                 self.log.error('Failed to send email to: %s', task.email)
-                self.log.exception(e2)
+                self.log.exception(exec2)
 
         # Handling callbacks pessimistically
         if callback:
             try:
                 callback(context)
-            except Exception as e3:
+            except Exception as exec3:     # pylint: disable=broad-except
                 self.log.error("Failed at executing callback")
-                self.log.exception(e3)
+                self.log.exception(exec3)
 
         if not test_mode:
             session.merge(self)
@@ -1263,7 +1338,8 @@ class TaskInstance(Base, LoggingMixin):
         return ''
 
     @provide_session
-    def get_template_context(self, session=None) -> Dict[str, Any]:
+    def get_template_context(self, session=None) -> Dict[str, Any]:  # pylint: disable=too-many-locals
+        """Return TI Context"""
         task = self.task
         from airflow import macros
 
@@ -1352,8 +1428,9 @@ class TaskInstance(Base, LoggingMixin):
             @staticmethod
             def get(
                 item: str,
-                default_var: Any = Variable._Variable__NO_DEFAULT_SENTINEL,
+                default_var: Any = Variable._Variable__NO_DEFAULT_SENTINEL,   # pylint: disable=W0212
             ):
+                """Get Airflow Variable value"""
                 return Variable.get(item, default_var=default_var)
 
         class VariableJsonAccessor:
@@ -1378,8 +1455,9 @@ class TaskInstance(Base, LoggingMixin):
             @staticmethod
             def get(
                 item: str,
-                default_var: Any = Variable._Variable__NO_DEFAULT_SENTINEL,
+                default_var: Any = Variable._Variable__NO_DEFAULT_SENTINEL,   # pylint: disable=W0212
             ):
+                """Get Airflow Variable after deserializing JSON value"""
                 return Variable.get(item, default_var=default_var, deserialize_json=True)
 
         return {
@@ -1447,7 +1525,9 @@ class TaskInstance(Base, LoggingMixin):
             self.render_templates()
 
     def overwrite_params_with_dag_run_conf(self, params, dag_run):
+        """Overwrite Task Params with DagRun.conf"""
         if dag_run and dag_run.conf:
+            self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
             params.update(dag_run.conf)
 
     def render_templates(self, context: Optional[Dict] = None) -> None:
@@ -1458,6 +1538,7 @@ class TaskInstance(Base, LoggingMixin):
         self.task.render_template_fields(context)
 
     def email_alert(self, exception):
+        """Send Email Alert with exception trace"""
         exception_html = str(exception).replace('\n', '<br>')
         jinja_context = self.get_template_context()
         # This function is called after changing the state
@@ -1495,7 +1576,7 @@ class TaskInstance(Base, LoggingMixin):
         html_content = render('html_content_template', default_html_content)
         try:
             send_email(self.task.email, subject, html_content)
-        except Exception:
+        except Exception:     # pylint: disable=broad-except
             default_html_content_err = (
                 'Try {{try_number}} out of {{max_tries + 1}}<br>'
                 'Exception:<br>Failed attempt to attach error logs<br>'
@@ -1508,10 +1589,12 @@ class TaskInstance(Base, LoggingMixin):
             send_email(self.task.email, subject, html_content_err)
 
     def set_duration(self) -> None:
+        """Set TI duration"""
         if self.end_date and self.start_date:
             self.duration = (self.end_date - self.start_date).total_seconds()
         else:
             self.duration = None
+        self.log.debug("Task Duration set to %s", self.duration)
 
     def xcom_push(
             self,
@@ -1545,7 +1628,7 @@ class TaskInstance(Base, LoggingMixin):
             dag_id=self.dag_id,
             execution_date=execution_date or self.execution_date)
 
-    def xcom_pull(
+    def xcom_pull(      # pylint: disable=inconsistent-return-statements
             self,
             task_ids: Optional[Union[str, Iterable[str]]] = None,
             dag_id: Optional[str] = None,
@@ -1605,6 +1688,7 @@ class TaskInstance(Base, LoggingMixin):
 
     @provide_session
     def get_num_running_task_instances(self, session):
+        """Return Number of running TIs from the DB"""
         # .count() is inefficient
         return session.query(func.count()).filter(
             TaskInstance.dag_id == self.dag_id,
diff --git a/airflow/operators/branch_operator.py b/airflow/operators/branch_operator.py
index 247d4cc..a465341 100644
--- a/airflow/operators/branch_operator.py
+++ b/airflow/operators/branch_operator.py
@@ -19,7 +19,8 @@
 
 from typing import Dict, Iterable, Union
 
-from airflow.models import BaseOperator, SkipMixin
+from airflow.models import BaseOperator
+from airflow.models.skipmixin import SkipMixin
 
 
 class BaseBranchOperator(BaseOperator, SkipMixin):
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 7107e17..5bbc715 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -32,8 +32,9 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
 import dill
 
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator, SkipMixin
+from airflow.models import BaseOperator
 from airflow.models.dag import DAG, DagContext
+from airflow.models.skipmixin import SkipMixin
 from airflow.models.xcom_arg import XComArg
 from airflow.utils.decorators import apply_defaults
 from airflow.utils.process_utils import execute_in_subprocess
diff --git a/airflow/ti_deps/deps/dagrun_exists_dep.py b/airflow/ti_deps/deps/dagrun_exists_dep.py
index 92c0ad8..b04daa8 100644
--- a/airflow/ti_deps/deps/dagrun_exists_dep.py
+++ b/airflow/ti_deps/deps/dagrun_exists_dep.py
@@ -34,7 +34,7 @@ class DagrunRunningDep(BaseTIDep):
         dagrun = ti.get_dagrun(session)
         if not dagrun:
             # The import is needed here to avoid a circular dependency
-            from airflow.models import DagRun
+            from airflow.models.dagrun import DagRun
             running_dagruns = DagRun.find(
                 dag_id=dag.dag_id,
                 state=State.RUNNING,
diff --git a/airflow/ti_deps/deps/ready_to_reschedule.py b/airflow/ti_deps/deps/ready_to_reschedule.py
index 57e7dee..5f3530c 100644
--- a/airflow/ti_deps/deps/ready_to_reschedule.py
+++ b/airflow/ti_deps/deps/ready_to_reschedule.py
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from airflow.models import TaskReschedule
+from airflow.models.taskreschedule import TaskReschedule
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.utils import timezone
 from airflow.utils.session import provide_session
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index 06ed651..b8ef213 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -23,7 +23,9 @@ from unittest.mock import Mock, patch
 import pendulum
 
 from airflow import settings
-from airflow.models import DAG, SkipMixin, TaskInstance as TI
+from airflow.models.dag import DAG
+from airflow.models.skipmixin import SkipMixin
+from airflow.models.taskinstance import TaskInstance as TI
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.utils import timezone
 from airflow.utils.state import State

Reply via email to