kaxil closed pull request #4035: [AIRFLOW-3190] Make flake8 compliant
URL: https://github.com/apache/incubator-airflow/pull/4035
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000000..2723df1f10
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 110
+ignore = E731
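For reference, a minimal sketch (not part of the PR) of what this new repo-wide configuration tolerates: lines may now run to 110 characters before E501 fires, and E731 (assigning a lambda to a name) is suppressed, so code like the following hypothetical snippet passes:

    # hypothetical example, not taken from the Airflow codebase
    power = lambda base, exp=2: base ** exp  # E731 would normally flag this; it is ignored here

    print(power(3))  # 9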
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 90452d954b..f7f69ac0b3 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -31,4 +31,4 @@ Make sure you have checked _all_ steps below.
### Code Quality
-- [ ] Passes `git diff upstream/master -u -- "*.py" | flake8 --diff`
+- [ ] Passes `flake8`
diff --git a/airflow/contrib/hooks/gcp_dataproc_hook.py b/airflow/contrib/hooks/gcp_dataproc_hook.py
index ca86f08795..e068d65cbe 100644
--- a/airflow/contrib/hooks/gcp_dataproc_hook.py
+++ b/airflow/contrib/hooks/gcp_dataproc_hook.py
@@ -78,8 +78,7 @@ def wait_for_done(self):
def raise_error(self, message=None):
job_state = self.job['status']['state']
# We always consider ERROR to be an error state.
-        if ((self.job_error_states and job_state in self.job_error_states)
-                or 'ERROR' == job_state):
+        if (self.job_error_states and job_state in self.job_error_states) or 'ERROR' == job_state:
             ex_message = message or ("Google DataProc job has state: %s" %
                                      job_state)
ex_details = (str(self.job['status']['details'])
if 'details' in self.job['status']
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 3922939a86..0bcb131c72 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -203,10 +203,10 @@ def run(self):
self._execute()
# In case of max runs or max duration
self.state = State.SUCCESS
- except SystemExit as e:
+ except SystemExit:
# In case of ^C or SIGTERM
self.state = State.SUCCESS
- except Exception as e:
+ except Exception:
self.state = State.FAILED
raise
finally:
@@ -424,7 +424,7 @@ def start(self):
def terminate(self, sigkill=False):
"""
Terminate (and then kill) the process launched to process the file.
-
+
:param sigkill: whether to issue a SIGKILL if SIGTERM doesn't work.
:type sigkill: bool
"""
@@ -453,7 +453,7 @@ def pid(self):
def exit_code(self):
"""
After the process is finished, this can be called to get the return
code
-
+
:return: the exit code of the process
:rtype: int
"""
@@ -465,7 +465,7 @@ def exit_code(self):
def done(self):
"""
Check if the process launched to process this file is done.
-
+
:return: whether the process is finished running
:rtype: bool
"""
@@ -2033,7 +2033,7 @@ def _update_counters(self, ti_status):
"""
Updates the counters per state of the tasks that were running. Can
re-add
to tasks to run in case required.
-
+
:param ti_status: the internal status of the backfill job tasks
:type ti_status: BackfillJob._DagRunTaskStatus
"""
@@ -2078,7 +2078,7 @@ def _manage_executor_state(self, running):
"""
Checks if the executor agrees with the state of task instances
that are running
-
+
:param running: dict of key, task to verify
"""
executor = self.executor
@@ -2110,7 +2110,7 @@ def _get_dag_run(self, run_date, session=None):
Returns a dag run for the given run date, which will be matched to an
existing
dag run if available or create a new dag run otherwise. If the
max_active_runs
limit is reached, this function will return None.
-
+
:param run_date: the execution date for the dag run
:type run_date: datetime
:param session: the database session object
@@ -2170,7 +2170,7 @@ def _task_instances_for_dag_run(self, dag_run, session=None):
"""
Returns a map of task instance key to task instance object for the
tasks to
run in the given dag run.
-
+
:param dag_run: the dag run to get the tasks from
:type dag_run: models.DagRun
:param session: the database session object
@@ -2236,7 +2236,7 @@ def _process_backfill_task_instances(self,
Process a set of task instances from a set of dag runs. Special
handling is done
to account for different task instance states that could be present
when running
them in a backfill process.
-
+
:param ti_status: the internal status of the job
:type ti_status: BackfillJob._DagRunTaskStatus
:param executor: the executor to run the task instances
@@ -2474,7 +2474,7 @@ def _execute_for_run_dates(self, run_dates, ti_status, executor, pickle_id,
Computes the dag runs and their respective task instances for
the given run dates and executes the task instances.
Returns a list of execution dates of the dag runs that were executed.
-
+
:param run_dates: Execution dates for dag runs
:type run_dates: list
:param ti_status: internal BackfillJob status structure to tis track
progress
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 6b2f3a639a..3cff4bb4c9 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -268,7 +268,7 @@ def initdb(rbac=False):
merge_conn(
models.Connection(
conn_id='qubole_default', conn_type='qubole',
- host= 'localhost'))
+ host='localhost'))
merge_conn(
models.Connection(
conn_id='segment_default', conn_type='segment',
diff --git a/airflow/utils/log/file_processor_handler.py b/airflow/utils/log/file_processor_handler.py
index f39dffe0c9..cc7a8bd843 100644
--- a/airflow/utils/log/file_processor_handler.py
+++ b/airflow/utils/log/file_processor_handler.py
@@ -116,7 +116,7 @@ def _symlink_latest_log_directory(self):
os.unlink(latest_log_directory_path)
os.symlink(log_directory, latest_log_directory_path)
elif (os.path.isdir(latest_log_directory_path) or
- os.path.isfile(latest_log_directory_path)):
+ os.path.isfile(latest_log_directory_path)):
logging.warning(
"%s already exists as a dir/file. Skip creating
symlink.",
latest_log_directory_path
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 76c112785f..7e97371319 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -37,12 +37,10 @@
utc = pendulum.timezone('UTC')
-def setup_event_handlers(
- engine,
- reconnect_timeout_seconds,
- initial_backoff_seconds=0.2,
- max_backoff_seconds=120):
-
+def setup_event_handlers(engine,
+ reconnect_timeout_seconds,
+ initial_backoff_seconds=0.2,
+ max_backoff_seconds=120):
@event.listens_for(engine, "engine_connect")
def ping_connection(connection, branch):
"""
@@ -100,7 +98,6 @@ def ping_connection(connection, branch):
# restore "close with result"
            connection.should_close_with_result = save_should_close_with_result
-
@event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record):
connection_record.info['pid'] = os.getpid()
diff --git a/airflow/utils/trigger_rule.py b/airflow/utils/trigger_rule.py
index ae51d6a301..7fdcbc8ca8 100644
--- a/airflow/utils/trigger_rule.py
+++ b/airflow/utils/trigger_rule.py
@@ -31,6 +31,7 @@ class TriggerRule(object):
DUMMY = 'dummy'
_ALL_TRIGGER_RULES = {}
+
@classmethod
def is_valid(cls, trigger_rule):
return trigger_rule in cls.all_triggers()
diff --git a/airflow/utils/weight_rule.py b/airflow/utils/weight_rule.py
index b920ef4022..f34856be83 100644
--- a/airflow/utils/weight_rule.py
+++ b/airflow/utils/weight_rule.py
@@ -28,6 +28,7 @@ class WeightRule(object):
ABSOLUTE = 'absolute'
_ALL_WEIGHT_RULES = {}
+
@classmethod
def is_valid(cls, weight_rule):
return weight_rule in cls.all_weight_rules()
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 0aef2281e7..a89e2847f3 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -682,12 +682,12 @@ def dag_details(self, session=None):
title = "DAG details"
TI = models.TaskInstance
- states = (
- session.query(TI.state, sqla.func.count(TI.dag_id))
- .filter(TI.dag_id == dag_id)
- .group_by(TI.state)
- .all()
- )
+ states = session\
+ .query(TI.state, sqla.func.count(TI.dag_id))\
+ .filter(TI.dag_id == dag_id)\
+ .group_by(TI.state)\
+ .all()
+
return self.render(
'airflow/dag_details.html',
dag=dag, title=title, states=states, State=State)
@@ -1192,12 +1192,12 @@ def dagrun_clear(self):
@provide_session
def blocked(self, session=None):
DR = models.DagRun
- dags = (
- session.query(DR.dag_id, sqla.func.count(DR.id))
- .filter(DR.state == State.RUNNING)
- .group_by(DR.dag_id)
- .all()
- )
+ dags = session\
+ .query(DR.dag_id, sqla.func.count(DR.id))\
+ .filter(DR.state == State.RUNNING)\
+ .group_by(DR.dag_id)\
+ .all()
+
payload = []
for dag_id, active_dag_runs in dags:
max_active_runs = 0
@@ -1454,8 +1454,8 @@ def recurse_nodes(task, visited):
children_key = "_children"
def set_duration(tid):
-            if (isinstance(tid, dict) and tid.get("state") == State.RUNNING and
-                    tid["start_date"] is not None):
+            if isinstance(tid, dict) and tid.get("state") == State.RUNNING \
+                    and tid["start_date"] is not None:
d = timezone.utcnow() - pendulum.parse(tid["start_date"])
tid["duration"] = d.total_seconds()
return tid
@@ -1482,9 +1482,7 @@ def set_duration(tid):
data = {
'name': '[DAG]',
'children': [recurse_nodes(t, set()) for t in dag.roots],
- 'instances': [
- dag_runs.get(d) or {'execution_date': d.isoformat()}
- for d in dates],
+            'instances': [dag_runs.get(d) or {'execution_date': d.isoformat()} for d in dates],
}
# minimize whitespace as this can be huge for bigger dags
@@ -2338,13 +2336,10 @@ class SlaMissModelView(wwwutils.SuperUserMixin, ModelViewOnly):
@provide_session
def _connection_ids(session=None):
- return [
- (c.conn_id, c.conn_id)
- for c in (
- session.query(models.Connection.conn_id)
- .group_by(models.Connection.conn_id)
- )
- ]
+ return [(c.conn_id, c.conn_id) for c in (
+ session
+ .query(models.Connection.conn_id)
+ .group_by(models.Connection.conn_id))]
class ChartModelView(wwwutils.DataProfilingMixin, AirflowModelView):
@@ -3144,20 +3139,16 @@ def get_query(self):
"""
Default filters for model
"""
- return (
- super(DagModelView, self)
- .get_query()
-            .filter(or_(models.DagModel.is_active, models.DagModel.is_paused))
- .filter(~models.DagModel.is_subdag)
- )
+ return super(DagModelView, self)\
+ .get_query()\
+ .filter(or_(models.DagModel.is_active, models.DagModel.is_paused))\
+ .filter(~models.DagModel.is_subdag)
def get_count_query(self):
"""
Default filters for model
"""
- return (
- super(DagModelView, self)
- .get_count_query()
- .filter(models.DagModel.is_active)
- .filter(~models.DagModel.is_subdag)
- )
+ return super(DagModelView, self)\
+ .get_count_query()\
+ .filter(models.DagModel.is_active)\
+ .filter(~models.DagModel.is_subdag)
diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py
index e6e505c41a..e7db7c651c 100644
--- a/airflow/www_rbac/views.py
+++ b/airflow/www_rbac/views.py
@@ -2111,7 +2111,7 @@ def varimport(self):
suc_count += 1
flash("{} variable(s) successfully updated.".format(suc_count))
if fail_count:
- flash("{} variables(s) failed to be
updated.".format(fail_count), 'error')
+ flash("{} variable(s) failed to be
updated.".format(fail_count), 'error')
self.update_redirect()
return redirect(self.get_redirect())
@@ -2353,7 +2353,7 @@ def action_clear(self, tis, session=None):
self.update_redirect()
return redirect(self.get_redirect())
- except Exception as ex:
+ except Exception:
flash('Failed to clear task instances', 'error')
@provide_session
diff --git a/scripts/ci/flake8-diff.sh b/scripts/ci/flake8-diff.sh
deleted file mode 100755
index 376be9bc0f..0000000000
--- a/scripts/ci/flake8-diff.sh
+++ /dev/null
@@ -1,164 +0,0 @@
-#!/usr/bin/env bash
-
-# Copyright (c) 2007–2017 The scikit-learn developers.
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# 1. Redistributions of source code must retain the above copyright notice,
this
-# list of conditions and the following disclaimer.
-# 2. Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-# 3. Neither the name of the Scikit-learn Developers nor the names of
-# its contributors may be used to endorse or promote products
-# derived from this software without specific prior written
-# permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND
-# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR
-# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES
-# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-# This script is used in Travis to check that PRs do not add obvious
-# flake8 violations. It relies on two things:
-# - find common ancestor between branch and
-# apache/incubator-airflow remote
-# - run flake8 --diff on the diff between the branch and the common
-# ancestor
-#
-# Additional features:
-# - the line numbers in Travis match the local branch on the PR
-# author machine.
-# - ./build_tools/travis/flake8_diff.sh can be run locally for quick
-# turn-around
-
-set -e
-# pipefail is necessary to propagate exit codes
-set -o pipefail
-
-PROJECT=apache/incubator-airflow
-PROJECT_URL=https://github.com/$PROJECT.git
-
-# Find the remote with the project name (upstream in most cases)
-REMOTE=$(git remote -v | grep $PROJECT | cut -f1 | head -1 || echo '')
-
-# Add a temporary remote if needed. For example this is necessary when
-# Travis is configured to run in a fork. In this case 'origin' is the
-# fork and not the reference repo we want to diff against.
-if [[ -z "$REMOTE" ]]; then
- TMP_REMOTE=tmp_reference_upstream
- REMOTE=$TMP_REMOTE
- git remote add $REMOTE $PROJECT_URL
-fi
-
-echo "Remotes:"
-echo '--------------------------------------------------------------------------------'
-git remote --verbose
-
-# Travis does the git clone with a limited depth (50 at the time of
-# writing). This may not be enough to find the common ancestor with
-# $REMOTE/master so we unshallow the git checkout
-if [[ -a .git/shallow ]]; then
- echo -e '\nTrying to unshallow the repo:'
-    echo '--------------------------------------------------------------------------------'
- git fetch --unshallow
-fi
-
-if [[ "$TRAVIS" == "true" ]]; then
- if [[ "$TRAVIS_PULL_REQUEST" == "false" ]]
- then
- # In main repo, using TRAVIS_COMMIT_RANGE to test the commits
- # that were pushed into a branch
- if [[ "$PROJECT" == "$TRAVIS_REPO_SLUG" ]]; then
- if [[ -z "$TRAVIS_COMMIT_RANGE" ]]; then
- echo "New branch, no commit range from Travis so passing this
test by convention"
- exit 0
- fi
- COMMIT_RANGE=$TRAVIS_COMMIT_RANGE
- fi
- else
- # We want to fetch the code as it is in the PR branch and not
- # the result of the merge into master. This way line numbers
- # reported by Travis will match with the local code.
- LOCAL_BRANCH_REF=travis_pr_$TRAVIS_PULL_REQUEST
- # In Travis the PR target is always origin
- git fetch origin pull/$TRAVIS_PULL_REQUEST/head:refs/$LOCAL_BRANCH_REF
- fi
-fi
-
-# If not using the commit range from Travis we need to find the common
-# ancestor between $LOCAL_BRANCH_REF and $REMOTE/master
-if [[ -z "$COMMIT_RANGE" ]]; then
- if [[ -z "$LOCAL_BRANCH_REF" ]]; then
- LOCAL_BRANCH_REF=$(git rev-parse --abbrev-ref HEAD)
- fi
- echo -e "\nLast 2 commits in $LOCAL_BRANCH_REF:"
-    echo '--------------------------------------------------------------------------------'
- git --no-pager log -2 $LOCAL_BRANCH_REF
-
- REMOTE_MASTER_REF="$REMOTE/master"
- # Make sure that $REMOTE_MASTER_REF is a valid reference
- echo -e "\nFetching $REMOTE_MASTER_REF"
-    echo '--------------------------------------------------------------------------------'
- git fetch $REMOTE master:refs/remotes/$REMOTE_MASTER_REF
- LOCAL_BRANCH_SHORT_HASH=$(git rev-parse --short $LOCAL_BRANCH_REF)
- REMOTE_MASTER_SHORT_HASH=$(git rev-parse --short $REMOTE_MASTER_REF)
-
- COMMIT=$(git merge-base $LOCAL_BRANCH_REF $REMOTE_MASTER_REF) || \
- echo "No common ancestor found for $(git show $LOCAL_BRANCH_REF -q)
and $(git show $REMOTE_MASTER_REF -q)"
-
- if [ -z "$COMMIT" ]; then
- exit 1
- fi
-
- COMMIT_SHORT_HASH=$(git rev-parse --short $COMMIT)
-
- echo -e "\nCommon ancestor between $LOCAL_BRANCH_REF
($LOCAL_BRANCH_SHORT_HASH)"\
- "and $REMOTE_MASTER_REF ($REMOTE_MASTER_SHORT_HASH) is
$COMMIT_SHORT_HASH:"
-    echo '--------------------------------------------------------------------------------'
- git --no-pager show --no-patch $COMMIT_SHORT_HASH
-
- COMMIT_RANGE="$COMMIT_SHORT_HASH..$LOCAL_BRANCH_SHORT_HASH"
-
- if [[ -n "$TMP_REMOTE" ]]; then
- git remote remove $TMP_REMOTE
- fi
-
-else
- echo "Got the commit range from Travis: $COMMIT_RANGE"
-fi
-
-echo -e '\nRunning flake8 on the diff in the range' "$COMMIT_RANGE" \
- "($(git rev-list $COMMIT_RANGE | wc -l) commit(s)):"
-echo '--------------------------------------------------------------------------------'
-
-MODIFIED_FILES="$(git diff --name-only $COMMIT_RANGE || echo "no_match")"
-
-check_files() {
- files="$1"
- shift
- options="$*"
- if [ -n "$files" ]; then
- # Conservative approach: diff without context (--unified=0) so that
code
- # that was not changed does not create failures
- git diff --unified=0 $COMMIT_RANGE -- $files | flake8 --diff
--show-source $options
- fi
-}
-
-if [[ "$MODIFIED_FILES" == "no_match" ]]; then
- echo "No file outside ignored locations has been modified"
-else
-
- check_files "$(echo "$MODIFIED_FILES" | grep -v ^examples)"
- check_files "$(echo "$MODIFIED_FILES" | grep ^examples)" \
- --config ./examples/.flake8
-fi
-echo -e "No problem detected by flake8\n"
diff --git a/setup.py b/setup.py
index 76f55ab01b..b1376bb5bb 100644
--- a/setup.py
+++ b/setup.py
@@ -45,8 +45,7 @@ def verify_gpl_dependency():
if os.getenv("READTHEDOCS") == "True":
os.environ["SLUGIFY_USES_TEXT_UNIDECODE"] = "yes"
- if (not os.getenv("AIRFLOW_GPL_UNIDECODE")
- and not os.getenv("SLUGIFY_USES_TEXT_UNIDECODE") == "yes"):
+ if not os.getenv("AIRFLOW_GPL_UNIDECODE") and not
os.getenv("SLUGIFY_USES_TEXT_UNIDECODE") == "yes":
raise RuntimeError("By default one of Airflow's dependencies installs
a GPL "
"dependency (unidecode). To avoid this dependency
set "
"SLUGIFY_USES_TEXT_UNIDECODE=yes in your
environment when you "
diff --git a/tests/api/__init__.py b/tests/api/__init__.py
index 4067cc78ee..114d189da1 100644
--- a/tests/api/__init__.py
+++ b/tests/api/__init__.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/api/client/__init__.py b/tests/api/client/__init__.py
index 4067cc78ee..114d189da1 100644
--- a/tests/api/client/__init__.py
+++ b/tests/api/client/__init__.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/api/common/__init__.py b/tests/api/common/__init__.py
index 4067cc78ee..114d189da1 100644
--- a/tests/api/common/__init__.py
+++ b/tests/api/common/__init__.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/api/common/experimental/__init__.py b/tests/api/common/experimental/__init__.py
index 4067cc78ee..114d189da1 100644
--- a/tests/api/common/experimental/__init__.py
+++ b/tests/api/common/experimental/__init__.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/api/common/experimental/mark_tasks.py b/tests/api/common/experimental/mark_tasks.py
index 9bba91bee0..304e261b98 100644
--- a/tests/api/common/experimental/mark_tasks.py
+++ b/tests/api/common/experimental/mark_tasks.py
@@ -72,7 +72,7 @@ def tearDown(self):
def snapshot_state(self, dag, execution_dates):
TI = models.TaskInstance
tis = self.session.query(TI).filter(
- TI.dag_id==dag.dag_id,
+ TI.dag_id == dag.dag_id,
TI.execution_date.in_(execution_dates)
).all()
@@ -84,7 +84,7 @@ def verify_state(self, dag, task_ids, execution_dates, state, old_tis):
TI = models.TaskInstance
tis = self.session.query(TI).filter(
- TI.dag_id==dag.dag_id,
+ TI.dag_id == dag.dag_id,
TI.execution_date.in_(execution_dates)
).all()
@@ -95,9 +95,8 @@ def verify_state(self, dag, task_ids, execution_dates, state, old_tis):
self.assertEqual(ti.state, state)
else:
for old_ti in old_tis:
-                if (old_ti.task_id == ti.task_id
-                        and old_ti.execution_date == ti.execution_date):
-                    self.assertEqual(ti.state, old_ti.state)
+                if old_ti.task_id == ti.task_id and old_ti.execution_date == ti.execution_date:
+                    self.assertEqual(ti.state, old_ti.state)
def test_mark_tasks_now(self):
# set one task to success but do not commit
@@ -435,19 +434,19 @@ def test_set_state_without_commit(self):
self._verify_task_instance_states_remain_default(dr)
def test_set_state_with_multiple_dagruns(self):
- dr1 = self.dag2.create_dagrun(
+ self.dag2.create_dagrun(
run_id='manual__' + datetime.now().isoformat(),
state=State.FAILED,
execution_date=self.execution_dates[0],
session=self.session
)
- dr2 = self.dag2.create_dagrun(
+ self.dag2.create_dagrun(
run_id='manual__' + datetime.now().isoformat(),
state=State.FAILED,
execution_date=self.execution_dates[1],
session=self.session
)
- dr3 = self.dag2.create_dagrun(
+ self.dag2.create_dagrun(
run_id='manual__' + datetime.now().isoformat(),
state=State.RUNNING,
execution_date=self.execution_dates[2],
@@ -468,13 +467,11 @@ def count_dag_tasks(dag):
self._verify_dag_run_state(self.dag2, self.execution_dates[1],
State.SUCCESS)
# Make sure other dag status are not changed
- dr1 = models.DagRun.find(dag_id=self.dag2.dag_id,
- execution_date=self.execution_dates[0])
- dr1 = dr1[0]
+ models.DagRun.find(dag_id=self.dag2.dag_id,
+ execution_date=self.execution_dates[0])
self._verify_dag_run_state(self.dag2, self.execution_dates[0],
State.FAILED)
- dr3 = models.DagRun.find(dag_id=self.dag2.dag_id,
- execution_date=self.execution_dates[2])
- dr3 = dr3[0]
+ models.DagRun.find(dag_id=self.dag2.dag_id,
+ execution_date=self.execution_dates[2])
self._verify_dag_run_state(self.dag2, self.execution_dates[2],
State.RUNNING)
def test_set_dag_run_state_edge_cases(self):
diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py
index 090f46caeb..597c881929 100644
--- a/tests/contrib/hooks/test_databricks_hook.py
+++ b/tests/contrib/hooks/test_databricks_hook.py
@@ -127,6 +127,7 @@ def terminate_cluster_endpoint(host):
"""
return 'https://{}/api/2.0/clusters/delete'.format(host)
+
def create_valid_response_mock(content):
response = mock.MagicMock()
response.json.return_value = content
@@ -143,13 +144,11 @@ def create_post_side_effect(exception, status_code=500):
return response
-def setup_mock_requests(
- mock_requests,
- exception,
- status_code=500,
- error_count=None,
- response_content=None):
-
+def setup_mock_requests(mock_requests,
+ exception,
+ status_code=500,
+ error_count=None,
+ response_content=None):
side_effect = create_post_side_effect(exception, status_code)
if error_count is None:
@@ -165,6 +164,7 @@ class DatabricksHookTest(unittest.TestCase):
"""
Tests for DatabricksHook.
"""
+
@db.provide_session
def setUp(self, session=None):
conn = session.query(Connection) \
@@ -191,21 +191,19 @@ def test_init_bad_retry_limit(self):
DatabricksHook(retry_limit=0)
def test_do_api_call_retries_with_retryable_error(self):
- for exception in [
- requests_exceptions.ConnectionError,
- requests_exceptions.SSLError,
- requests_exceptions.Timeout,
- requests_exceptions.ConnectTimeout,
- requests_exceptions.HTTPError]:
-            with mock.patch(
-                    'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \
- mock.patch.object(self.hook.log, 'error') as mock_errors:
- setup_mock_requests(mock_requests, exception)
-
- with self.assertRaises(AirflowException):
- self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
-
-                self.assertEquals(mock_errors.call_count, self.hook.retry_limit)
+ for exception in [requests_exceptions.ConnectionError,
+ requests_exceptions.SSLError,
+ requests_exceptions.Timeout,
+ requests_exceptions.ConnectTimeout,
+ requests_exceptions.HTTPError]:
+            with mock.patch('airflow.contrib.hooks.databricks_hook.requests') as mock_requests:
+ with mock.patch.object(self.hook.log, 'error') as mock_errors:
+ setup_mock_requests(mock_requests, exception)
+
+ with self.assertRaises(AirflowException):
+ self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+                    self.assertEquals(mock_errors.call_count, self.hook.retry_limit)
@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_do_api_call_does_not_retry_with_non_retryable_error(self,
mock_requests):
@@ -220,56 +218,52 @@ def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests
mock_errors.assert_not_called()
def test_do_api_call_succeeds_after_retrying(self):
- for exception in [
- requests_exceptions.ConnectionError,
- requests_exceptions.SSLError,
- requests_exceptions.Timeout,
- requests_exceptions.ConnectTimeout,
- requests_exceptions.HTTPError]:
-            with mock.patch(
-                    'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \
- mock.patch.object(self.hook.log, 'error') as mock_errors:
- setup_mock_requests(
- mock_requests,
- exception,
- error_count=2,
- response_content={'run_id': '1'}
- )
-
- response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
-
- self.assertEquals(mock_errors.call_count, 2)
- self.assertEquals(response, {'run_id': '1'})
+ for exception in [requests_exceptions.ConnectionError,
+ requests_exceptions.SSLError,
+ requests_exceptions.Timeout,
+ requests_exceptions.ConnectTimeout,
+ requests_exceptions.HTTPError]:
+            with mock.patch('airflow.contrib.hooks.databricks_hook.requests') as mock_requests:
+ with mock.patch.object(self.hook.log, 'error') as mock_errors:
+ setup_mock_requests(
+ mock_requests,
+ exception,
+ error_count=2,
+ response_content={'run_id': '1'}
+ )
+
+ response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+ self.assertEquals(mock_errors.call_count, 2)
+ self.assertEquals(response, {'run_id': '1'})
@mock.patch('airflow.contrib.hooks.databricks_hook.sleep')
def test_do_api_call_waits_between_retries(self, mock_sleep):
retry_delay = 5
self.hook = DatabricksHook(retry_delay=retry_delay)
- for exception in [
- requests_exceptions.ConnectionError,
- requests_exceptions.SSLError,
- requests_exceptions.Timeout,
- requests_exceptions.ConnectTimeout,
- requests_exceptions.HTTPError]:
-            with mock.patch(
-                    'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \
- mock.patch.object(self.hook.log, 'error'):
- mock_sleep.reset_mock()
- setup_mock_requests(mock_requests, exception)
+ for exception in [requests_exceptions.ConnectionError,
+ requests_exceptions.SSLError,
+ requests_exceptions.Timeout,
+ requests_exceptions.ConnectTimeout,
+ requests_exceptions.HTTPError]:
+            with mock.patch('airflow.contrib.hooks.databricks_hook.requests') as mock_requests:
+ with mock.patch.object(self.hook.log, 'error'):
+ mock_sleep.reset_mock()
+ setup_mock_requests(mock_requests, exception)
- with self.assertRaises(AirflowException):
- self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+ with self.assertRaises(AirflowException):
+ self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
-                self.assertEquals(len(mock_sleep.mock_calls), self.hook.retry_limit - 1)
- mock_sleep.assert_called_with(retry_delay)
+                    self.assertEquals(len(mock_sleep.mock_calls), self.hook.retry_limit - 1)
+ mock_sleep.assert_called_with(retry_delay)
@mock.patch('airflow.contrib.hooks.databricks_hook.requests')
def test_submit_run(self, mock_requests):
mock_requests.post.return_value.json.return_value = {'run_id': '1'}
json = {
- 'notebook_task': NOTEBOOK_TASK,
- 'new_cluster': NEW_CLUSTER
+ 'notebook_task': NOTEBOOK_TASK,
+ 'new_cluster': NEW_CLUSTER
}
run_id = self.hook.submit_run(json)
@@ -407,6 +401,7 @@ class DatabricksHookTokenTest(unittest.TestCase):
"""
Tests for DatabricksHook when auth is done with token.
"""
+
@db.provide_session
def setUp(self, session=None):
conn = session.query(Connection) \
@@ -424,8 +419,8 @@ def test_submit_run(self, mock_requests):
status_code_mock = mock.PropertyMock(return_value=200)
type(mock_requests.post.return_value).status_code = status_code_mock
json = {
- 'notebook_task': NOTEBOOK_TASK,
- 'new_cluster': NEW_CLUSTER
+ 'notebook_task': NOTEBOOK_TASK,
+ 'new_cluster': NEW_CLUSTER
}
run_id = self.hook.submit_run(json)
diff --git a/tests/contrib/hooks/test_emr_hook.py b/tests/contrib/hooks/test_emr_hook.py
index edb2dbb049..07c20e69cf 100644
--- a/tests/contrib/hooks/test_emr_hook.py
+++ b/tests/contrib/hooks/test_emr_hook.py
@@ -24,7 +24,6 @@
from airflow import configuration
from airflow.contrib.hooks.emr_hook import EmrHook
-
try:
from moto import mock_emr
except ImportError:
@@ -54,5 +53,6 @@ def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self):
self.assertEqual(client.list_clusters()['Clusters'][0]['Id'],
cluster['JobFlowId'])
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/contrib/hooks/test_gcp_mlengine_hook.py b/tests/contrib/hooks/test_gcp_mlengine_hook.py
index c3bc7a9c0d..f986354503 100644
--- a/tests/contrib/hooks/test_gcp_mlengine_hook.py
+++ b/tests/contrib/hooks/test_gcp_mlengine_hook.py
@@ -185,8 +185,7 @@ def test_list_versions(self):
self._SERVICE_URI_PREFIX, project, model_name), 'GET',
None),
] + [
- ('{}projects/{}/models/{}/versions?alt=json&pageToken={}'
- '&pageSize=100'.format(
+            ('{}projects/{}/models/{}/versions?alt=json&pageToken={}&pageSize=100'.format(
self._SERVICE_URI_PREFIX, project, model_name, ix), 'GET',
None) for ix in range(len(versions) - 1)
]
diff --git a/tests/contrib/hooks/test_jdbc_hook.py b/tests/contrib/hooks/test_jdbc_hook.py
index fd4a4fc337..3f708997d9 100644
--- a/tests/contrib/hooks/test_jdbc_hook.py
+++ b/tests/contrib/hooks/test_jdbc_hook.py
@@ -19,6 +19,7 @@
#
import unittest
+import json
from mock import Mock
from mock import patch
@@ -29,7 +30,7 @@
from airflow.utils import db
jdbc_conn_mock = Mock(
- name="jdbc_conn"
+ name="jdbc_conn"
)
@@ -37,10 +38,11 @@ class TestJdbcHook(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
db.merge_conn(
- models.Connection(
- conn_id='jdbc_default', conn_type='jdbc',
- host='jdbc://localhost/', port=443,
- extra='{"extra__jdbc__drv_path":
"/path1/test.jar,/path2/t.jar2", "extra__jdbc__drv_clsname":
"com.driver.main"}'))
+ models.Connection(
+ conn_id='jdbc_default', conn_type='jdbc',
+ host='jdbc://localhost/', port=443,
+ extra=json.dumps({"extra__jdbc__drv_path":
"/path1/test.jar,/path2/t.jar2",
+ "extra__jdbc__drv_clsname":
"com.driver.main"})))
@patch("airflow.hooks.jdbc_hook.jaydebeapi.connect", autospec=True,
return_value=jdbc_conn_mock)
diff --git a/tests/contrib/hooks/test_jira_hook.py b/tests/contrib/hooks/test_jira_hook.py
index 029a452990..378c379d55 100644
--- a/tests/contrib/hooks/test_jira_hook.py
+++ b/tests/contrib/hooks/test_jira_hook.py
@@ -29,7 +29,7 @@
from airflow.utils import db
jira_client_mock = Mock(
- name="jira_client"
+ name="jira_client"
)
@@ -37,10 +37,10 @@ class TestJiraHook(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
db.merge_conn(
- models.Connection(
- conn_id='jira_default', conn_type='jira',
- host='https://localhost/jira/', port=443,
- extra='{"verify": "False", "project": "AIRFLOW"}'))
+ models.Connection(
+ conn_id='jira_default', conn_type='jira',
+ host='https://localhost/jira/', port=443,
+ extra='{"verify": "False", "project": "AIRFLOW"}'))
@patch("airflow.contrib.hooks.jira_hook.JIRA", autospec=True,
return_value=jira_client_mock)
diff --git a/tests/contrib/hooks/test_spark_sql_hook.py b/tests/contrib/hooks/test_spark_sql_hook.py
index 47ccd618b3..f76768efcd 100644
--- a/tests/contrib/hooks/test_spark_sql_hook.py
+++ b/tests/contrib/hooks/test_spark_sql_hook.py
@@ -35,8 +35,8 @@ def get_after(sentinel, iterable):
next(truncated)
return next(truncated)
-class TestSparkSqlHook(unittest.TestCase):
+class TestSparkSqlHook(unittest.TestCase):
_config = {
'conn_id': 'spark_default',
'executor_cores': 4,
@@ -98,7 +98,8 @@ def test_spark_process_runcmd(self, mock_popen):
hook.run_query()
mock_debug.assert_called_with(
'Spark-Sql cmd: %s',
-            ['spark-sql', '-e', 'SELECT 1', '--master', 'yarn', '--name', 'default-name', '--verbose', '--queue', 'default']
+            ['spark-sql', '-e', 'SELECT 1', '--master', 'yarn', '--name', 'default-name', '--verbose',
+             '--queue', 'default']
)
mock_info.assert_called_with(
'Spark-sql communicates using stdout'
@@ -107,7 +108,8 @@ def test_spark_process_runcmd(self, mock_popen):
# Then
self.assertEqual(
mock_popen.mock_calls[0],
-            call(['spark-sql', '-e', 'SELECT 1', '--master', 'yarn', '--name', 'default-name', '--verbose', '--queue', 'default'], stderr=-2, stdout=-1)
+            call(['spark-sql', '-e', 'SELECT 1', '--master', 'yarn', '--name', 'default-name', '--verbose',
+                  '--queue', 'default'], stderr=-2, stdout=-1)
)
diff --git a/tests/contrib/hooks/test_sqoop_hook.py b/tests/contrib/hooks/test_sqoop_hook.py
index 649500a364..8bef5a4937 100644
--- a/tests/contrib/hooks/test_sqoop_hook.py
+++ b/tests/contrib/hooks/test_sqoop_hook.py
@@ -98,7 +98,8 @@ def test_popen(self, mock_popen):
mock_popen.return_value.stdout = StringIO(u'stdout')
mock_popen.return_value.stderr = StringIO(u'stderr')
mock_popen.return_value.returncode = 0
-        mock_popen.return_value.communicate.return_value = [StringIO(u'stdout\nstdout'), StringIO(u'stderr\nstderr')]
+ mock_popen.return_value.communicate.return_value = \
+ [StringIO(u'stdout\nstdout'), StringIO(u'stderr\nstderr')]
# When
hook = SqoopHook(conn_id='sqoop_test')
@@ -163,7 +164,7 @@ def test_submit(self):
self.assertIn("-files {}".format(self._config_json['files']), cmd)
if self._config_json['archives']:
- self.assertIn( "-archives
{}".format(self._config_json['archives']), cmd)
+ self.assertIn("-archives
{}".format(self._config_json['archives']), cmd)
self.assertIn("--hcatalog-database
{}".format(self._config['hcatalog_database']), cmd)
self.assertIn("--hcatalog-table
{}".format(self._config['hcatalog_table']), cmd)
@@ -173,7 +174,7 @@ def test_submit(self):
self.assertIn("--verbose", cmd)
if self._config['num_mappers']:
- self.assertIn( "--num-mappers
{}".format(self._config['num_mappers']), cmd)
+ self.assertIn("--num-mappers
{}".format(self._config['num_mappers']), cmd)
for key, value in self._config['properties'].items():
self.assertIn("-D {}={}".format(key, value), cmd)
@@ -301,7 +302,8 @@ def test_import_cmd(self):
def test_get_export_format_argument(self):
"""
-        Tests to verify the hook get format function is building correct Sqoop command with correct format type.
+ Tests to verify the hook get format function is building
+ correct Sqoop command with correct format type.
"""
hook = SqoopHook()
self.assertIn("--as-avrodatafile",
diff --git a/tests/contrib/operators/__init__.py b/tests/contrib/operators/__init__.py
index 331c28ef9a..b7f8352944 100644
--- a/tests/contrib/operators/__init__.py
+++ b/tests/contrib/operators/__init__.py
@@ -17,4 +17,3 @@
# specific language governing permissions and limitations
# under the License.
#
-
diff --git a/tests/contrib/operators/test_awsbatch_operator.py b/tests/contrib/operators/test_awsbatch_operator.py
index 273edd5630..4808574f23 100644
--- a/tests/contrib/operators/test_awsbatch_operator.py
+++ b/tests/contrib/operators/test_awsbatch_operator.py
@@ -33,7 +33,6 @@
except ImportError:
mock = None
-
RESPONSE_WITHOUT_FAILURES = {
"jobName": "51455483-c62c-48ac-9b88-53a6a725baa3",
"jobId": "8ba9d676-4108-4474-9dca-8bbac1da9b19"
@@ -58,7 +57,6 @@ def setUp(self, aws_hook_mock):
region_name='eu-west-1')
def test_init(self):
-
self.assertEqual(self.batch.job_name,
'51455483-c62c-48ac-9b88-53a6a725baa3')
self.assertEqual(self.batch.job_queue, 'queue')
self.assertEqual(self.batch.job_definition, 'hello-world')
@@ -76,13 +74,13 @@ def test_template_fields_overrides(self):
@mock.patch.object(AWSBatchOperator, '_wait_for_task_ended')
@mock.patch.object(AWSBatchOperator, '_check_success_task')
def test_execute_without_failures(self, check_mock, wait_mock):
-
client_mock =
self.aws_hook_mock.return_value.get_client_type.return_value
client_mock.submit_job.return_value = RESPONSE_WITHOUT_FAILURES
self.batch.execute(None)
-        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('batch', region_name='eu-west-1')
+        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('batch',
+                                                                                region_name='eu-west-1')
client_mock.submit_job.assert_called_once_with(
jobQueue='queue',
jobName='51455483-c62c-48ac-9b88-53a6a725baa3',
@@ -95,14 +93,14 @@ def test_execute_without_failures(self, check_mock, wait_mock):
self.assertEqual(self.batch.jobId,
'8ba9d676-4108-4474-9dca-8bbac1da9b19')
def test_execute_with_failures(self):
-
client_mock =
self.aws_hook_mock.return_value.get_client_type.return_value
client_mock.submit_job.return_value = ""
with self.assertRaises(AirflowException):
self.batch.execute(None)
-        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('batch', region_name='eu-west-1')
+        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('batch',
+                                                                                region_name='eu-west-1')
client_mock.submit_job.assert_called_once_with(
jobQueue='queue',
jobName='51455483-c62c-48ac-9b88-53a6a725baa3',
@@ -111,7 +109,6 @@ def test_execute_with_failures(self):
)
def test_wait_end_tasks(self):
-
client_mock = mock.Mock()
self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
self.batch.client = client_mock
diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py
index 75602efb7d..af62a3e4c3 100644
--- a/tests/contrib/operators/test_databricks_operator.py
+++ b/tests/contrib/operators/test_databricks_operator.py
@@ -132,9 +132,9 @@ def test_init_with_specified_run_name(self):
Test the initializer with a specified run_name.
"""
json = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK,
- 'run_name': RUN_NAME
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
+ 'run_name': RUN_NAME
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
expected = databricks_operator._deep_string_coerce({
@@ -167,8 +167,8 @@ def test_init_with_merging(self):
def test_init_with_templating(self):
json = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': TEMPLATED_NOTEBOOK_TASK,
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': TEMPLATED_NOTEBOOK_TASK,
}
dag = DAG('test', start_date=datetime.now())
op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json)
@@ -196,8 +196,8 @@ def test_exec_success(self, db_mock_class):
Test the execute function in case where the run is successful.
"""
run = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK,
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
@@ -227,8 +227,8 @@ def test_exec_failure(self, db_mock_class):
Test the execute function in case where the run failed.
"""
run = {
- 'new_cluster': NEW_CLUSTER,
- 'notebook_task': NOTEBOOK_TASK,
+ 'new_cluster': NEW_CLUSTER,
+ 'notebook_task': NOTEBOOK_TASK,
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
diff --git a/tests/contrib/operators/test_dataproc_operator.py b/tests/contrib/operators/test_dataproc_operator.py
index 60c1268ee7..7141a03140 100644
--- a/tests/contrib/operators/test_dataproc_operator.py
+++ b/tests/contrib/operators/test_dataproc_operator.py
@@ -114,7 +114,7 @@ def setUp(self):
self.dataproc_operators = []
self.mock_conn = Mock()
for labels in self.labels:
- self.dataproc_operators.append(
+ self.dataproc_operators.append(
DataprocClusterCreateOperator(
task_id=TASK_ID,
cluster_name=CLUSTER_NAME,
@@ -140,7 +140,7 @@ def setUp(self):
auto_delete_time=AUTO_DELETE_TIME,
auto_delete_ttl=AUTO_DELETE_TTL
)
- )
+ )
self.dag = DAG(
'test_dag',
default_args={
@@ -214,7 +214,7 @@ def test_build_cluster_data(self):
# set to the dataproc operator.
merged_labels = {}
merged_labels.update(self.labels[suffix])
-            merged_labels.update({'airflow-version': 'v' + version.replace('.', '-').replace('+','-')})
+            merged_labels.update({'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')})
self.assertTrue(re.match(r'[a-z]([-a-z0-9]*[a-z0-9])?',
cluster_data['labels']['airflow-version']))
self.assertEqual(cluster_data['labels'], merged_labels)
@@ -299,8 +299,7 @@ def test_init_with_custom_image(self):
expected_custom_image_url)
def test_cluster_name_log_no_sub(self):
-        with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') \
-                as mock_hook:
+        with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') as mock_hook:
mock_hook.return_value.get_conn = self.mock_conn
dataproc_task = DataprocClusterCreateOperator(
task_id=TASK_ID,
@@ -311,13 +310,12 @@ def test_cluster_name_log_no_sub(self):
dag=self.dag
)
with patch.object(dataproc_task.log, 'info') as mock_info:
- with self.assertRaises(TypeError) as _:
+ with self.assertRaises(TypeError):
dataproc_task.execute(None)
mock_info.assert_called_with('Creating cluster: %s',
CLUSTER_NAME)
def test_cluster_name_log_sub(self):
-        with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') \
-                as mock_hook:
+        with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') as mock_hook:
mock_hook.return_value.get_conn = self.mock_conn
dataproc_task = DataprocClusterCreateOperator(
task_id=TASK_ID,
@@ -336,13 +334,11 @@ def test_cluster_name_log_sub(self):
setattr(dataproc_task, 'cluster_name', rendered)
with self.assertRaises(TypeError):
dataproc_task.execute(None)
- mock_info.assert_called_with('Creating cluster: %s',
- u'smoke-cluster-testnodash')
+            mock_info.assert_called_with('Creating cluster: %s', u'smoke-cluster-testnodash')
def test_build_cluster_data_internal_ip_only_without_subnetwork(self):
def create_cluster_with_invalid_internal_ip_only_setup():
-
# Given
create_cluster = DataprocClusterCreateOperator(
task_id=TASK_ID,
@@ -361,8 +357,7 @@ def create_cluster_with_invalid_internal_ip_only_setup():
create_cluster_with_invalid_internal_ip_only_setup()
self.assertEqual(str(cm.exception),
- "Set internal_ip_only to true only when"
- " you pass a subnetwork_uri.")
+ "Set internal_ip_only to true only when you pass a
subnetwork_uri.")
class DataprocClusterScaleOperatorTest(unittest.TestCase):
@@ -406,8 +401,7 @@ def test_cluster_name_log_no_sub(self):
mock_info.assert_called_with('Scaling cluster: %s',
CLUSTER_NAME)
def test_cluster_name_log_sub(self):
-        with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') \
-                as mock_hook:
+        with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') as mock_hook:
mock_hook.return_value.get_conn = self.mock_conn
dataproc_task = DataprocClusterScaleOperator(
task_id=TASK_ID,
@@ -427,22 +421,21 @@ def test_cluster_name_log_sub(self):
setattr(dataproc_task, 'cluster_name', rendered)
with self.assertRaises(TypeError):
dataproc_task.execute(None)
- mock_info.assert_called_with('Scaling cluster: %s',
- u'smoke-cluster-testnodash')
+            mock_info.assert_called_with('Scaling cluster: %s', u'smoke-cluster-testnodash')
class DataprocClusterDeleteOperatorTest(unittest.TestCase):
# Unit test for the DataprocClusterDeleteOperator
def setUp(self):
self.mock_execute = Mock()
- self.mock_execute.execute = Mock(return_value={'done' : True})
+ self.mock_execute.execute = Mock(return_value={'done': True})
self.mock_get = Mock()
self.mock_get.get = Mock(return_value=self.mock_execute)
self.mock_operations = Mock()
self.mock_operations.get = Mock(return_value=self.mock_get)
self.mock_regions = Mock()
self.mock_regions.operations = Mock(return_value=self.mock_operations)
- self.mock_projects=Mock()
+ self.mock_projects = Mock()
self.mock_projects.regions = Mock(return_value=self.mock_regions)
self.mock_conn = Mock()
self.mock_conn.projects = Mock(return_value=self.mock_projects)
diff --git a/tests/contrib/operators/test_ecs_operator.py b/tests/contrib/operators/test_ecs_operator.py
index 3b0f03351d..5f8c220260 100644
--- a/tests/contrib/operators/test_ecs_operator.py
+++ b/tests/contrib/operators/test_ecs_operator.py
@@ -34,14 +34,14 @@
except ImportError:
mock = None
-
RESPONSE_WITHOUT_FAILURES = {
"failures": [],
"tasks": [
{
"containers": [
{
- "containerArn":
"arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868",
+ "containerArn":
+
"arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868",
"lastStatus": "PENDING",
"name": "wordpress",
"taskArn":
"arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55"
@@ -85,7 +85,6 @@ def setUp(self, aws_hook_mock):
)
def test_init(self):
-
self.assertEqual(self.ecs.region_name, 'eu-west-1')
self.assertEqual(self.ecs.task_definition, 't')
self.assertEqual(self.ecs.aws_conn_id, None)
@@ -101,13 +100,13 @@ def test_template_fields_overrides(self):
@mock.patch.object(ECSOperator, '_wait_for_task_ended')
@mock.patch.object(ECSOperator, '_check_success_task')
def test_execute_without_failures(self, check_mock, wait_mock):
-
client_mock =
self.aws_hook_mock.return_value.get_client_type.return_value
client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES
self.ecs.execute(None)
-        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1')
+        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs',
+                                                                                region_name='eu-west-1')
client_mock.run_task.assert_called_once_with(
cluster='c',
launchType='EC2',
@@ -131,10 +130,10 @@ def test_execute_without_failures(self, check_mock, wait_mock):
wait_mock.assert_called_once_with()
check_mock.assert_called_once_with()
-        self.assertEqual(self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55')
+        self.assertEqual(self.ecs.arn,
+                         'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55')
def test_execute_with_failures(self):
-
client_mock =
self.aws_hook_mock.return_value.get_client_type.return_value
resp_failures = deepcopy(RESPONSE_WITHOUT_FAILURES)
resp_failures['failures'].append('dummy error')
@@ -143,7 +142,8 @@ def test_execute_with_failures(self):
with self.assertRaises(AirflowException):
self.ecs.execute(None)
-        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs', region_name='eu-west-1')
+        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('ecs',
+                                                                                region_name='eu-west-1')
client_mock.run_task.assert_called_once_with(
cluster='c',
launchType='EC2',
@@ -166,7 +166,6 @@ def test_execute_with_failures(self):
)
def test_wait_end_tasks(self):
-
client_mock = mock.Mock()
self.ecs.arn = 'arn'
self.ecs.client = client_mock
diff --git a/tests/contrib/operators/test_emr_terminate_job_flow_operator.py b/tests/contrib/operators/test_emr_terminate_job_flow_operator.py
index a4d43407c7..d25b02adb5 100644
--- a/tests/contrib/operators/test_emr_terminate_job_flow_operator.py
+++ b/tests/contrib/operators/test_emr_terminate_job_flow_operator.py
@@ -44,10 +44,8 @@ def setUp(self):
# Mock out the emr_client creator
self.boto3_session_mock = MagicMock(return_value=mock_emr_session)
-
def test_execute_terminates_the_job_flow_and_does_not_error(self):
with patch('boto3.session.Session', self.boto3_session_mock):
-
operator = EmrTerminateJobFlowOperator(
task_id='test_task',
job_flow_id='j-8989898989',
@@ -56,5 +54,6 @@ def test_execute_terminates_the_job_flow_and_does_not_error(self):
operator.execute(None)
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/contrib/operators/test_hipchat_operator.py b/tests/contrib/operators/test_hipchat_operator.py
index 71c93c9259..dbfb227d1d 100644
--- a/tests/contrib/operators/test_hipchat_operator.py
+++ b/tests/contrib/operators/test_hipchat_operator.py
@@ -47,7 +47,7 @@ def test_execute(self, request_mock):
operator = HipChatAPISendRoomNotificationOperator(
task_id='test_hipchat_success',
- owner = 'airflow',
+ owner='airflow',
token='abc123',
room_id='room_id',
message='hello world!'
diff --git a/tests/contrib/operators/test_hive_to_dynamodb_operator.py b/tests/contrib/operators/test_hive_to_dynamodb_operator.py
index 9ef3809593..ab86e05517 100644
--- a/tests/contrib/operators/test_hive_to_dynamodb_operator.py
+++ b/tests/contrib/operators/test_hive_to_dynamodb_operator.py
@@ -20,17 +20,18 @@
import json
import unittest
+import datetime
import mock
import pandas as pd
from airflow import configuration, DAG
-
-configuration.load_test_config()
-import datetime
from airflow.contrib.hooks.aws_dynamodb_hook import AwsDynamoDBHook
+
import airflow.contrib.operators.hive_to_dynamodb
+configuration.load_test_config()
+
DEFAULT_DATE = datetime.datetime(2015, 1, 1)
DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
@@ -67,9 +68,8 @@ def test_get_conn_returns_a_boto3_connection(self):
    @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present')
@mock_dynamodb2
def test_get_records_with_schema(self, get_results_mock):
-
# this table needs to be created in production
- table = self.hook.get_conn().create_table(
+ self.hook.get_conn().create_table(
TableName='test_airflow',
KeySchema=[
{
@@ -108,9 +108,8 @@ def test_get_records_with_schema(self, get_results_mock):
    @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present')
@mock_dynamodb2
def test_pre_process_records_with_schema(self, get_results_mock):
-
- # this table needs to be created in production
- table = self.hook.get_conn().create_table(
+ # this table needs to be created in production
+ self.hook.get_conn().create_table(
TableName='test_airflow',
KeySchema=[
{
@@ -141,8 +140,7 @@ def test_pre_process_records_with_schema(self, get_results_mock):
operator.execute(None)
table = self.hook.get_conn().Table('test_airflow')
- table.meta.client.get_waiter(
- 'table_exists').wait(TableName='test_airflow')
+        table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')
self.assertEqual(table.item_count, 1)
diff --git a/tests/contrib/operators/test_jira_operator_test.py b/tests/contrib/operators/test_jira_operator_test.py
index a358d3019f..2509038a36 100644
--- a/tests/contrib/operators/test_jira_operator_test.py
+++ b/tests/contrib/operators/test_jira_operator_test.py
@@ -31,7 +31,7 @@
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
jira_client_mock = Mock(
- name="jira_client_for_test"
+ name="jira_client_for_test"
)
minimal_test_ticket = {
@@ -54,14 +54,14 @@ def setUp(self):
args = {
'owner': 'airflow',
'start_date': DEFAULT_DATE
- }
+ }
dag = DAG('test_dag_id', default_args=args)
self.dag = dag
db.merge_conn(
- models.Connection(
- conn_id='jira_default', conn_type='jira',
- host='https://localhost/jira/', port=443,
- extra='{"verify": "False", "project": "AIRFLOW"}'))
+ models.Connection(
+ conn_id='jira_default', conn_type='jira',
+ host='https://localhost/jira/', port=443,
+ extra='{"verify": "False", "project": "AIRFLOW"}'))
@patch("airflow.contrib.hooks.jira_hook.JIRA",
autospec=True, return_value=jira_client_mock)
diff --git a/tests/contrib/operators/test_mlengine_operator_utils.py b/tests/contrib/operators/test_mlengine_operator_utils.py
index 09f0071e21..a072265041 100644
--- a/tests/contrib/operators/test_mlengine_operator_utils.py
+++ b/tests/contrib/operators/test_mlengine_operator_utils.py
@@ -36,7 +36,6 @@
class CreateEvaluateOpsTest(unittest.TestCase):
-
INPUT_MISSING_ORIGIN = {
'dataFormat': 'TEXT',
'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
@@ -89,7 +88,6 @@ def testSuccessfulRun(self):
with patch('airflow.contrib.operators.mlengine_operator.'
'MLEngineHook') as mock_mlengine_hook:
-
success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
success_message['predictionInput'] = input_with_model
hook_instance = mock_mlengine_hook.return_value
@@ -107,7 +105,6 @@ def testSuccessfulRun(self):
with patch('airflow.contrib.operators.dataflow_operator.'
'DataFlowHook') as mock_dataflow_hook:
-
hook_instance = mock_dataflow_hook.return_value
hook_instance.start_python_dataflow.return_value = None
summary.execute(None)
@@ -126,7 +123,6 @@ def testSuccessfulRun(self):
with patch('airflow.contrib.operators.mlengine_operator_utils.'
'GoogleCloudStorageHook') as mock_gcs_hook:
-
hook_instance = mock_gcs_hook.return_value
hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
result = validate.execute({})
@@ -159,27 +155,25 @@ def testFailures(self):
}
with self.assertRaisesRegexp(AirflowException, 'Missing model origin'):
- _ = create_evaluate_ops(**other_params_but_models)
+ create_evaluate_ops(**other_params_but_models)
        with self.assertRaisesRegexp(AirflowException, 'Ambiguous model origin'):
-            _ = create_evaluate_ops(model_uri='abc', model_name='cde',
-                                    **other_params_but_models)
+            create_evaluate_ops(model_uri='abc', model_name='cde', **other_params_but_models)
        with self.assertRaisesRegexp(AirflowException, 'Ambiguous model origin'):
-            _ = create_evaluate_ops(model_uri='abc', version_name='vvv',
-                                    **other_params_but_models)
+            create_evaluate_ops(model_uri='abc', version_name='vvv', **other_params_but_models)
with self.assertRaisesRegexp(AirflowException,
'`metric_fn` param must be callable'):
params = other_params_but_models.copy()
params['metric_fn_and_keys'] = (None, ['abc'])
- _ = create_evaluate_ops(model_uri='gs://blah', **params)
+ create_evaluate_ops(model_uri='gs://blah', **params)
with self.assertRaisesRegexp(AirflowException,
'`validate_fn` param must be callable'):
params = other_params_but_models.copy()
params['validate_fn'] = None
- _ = create_evaluate_ops(model_uri='gs://blah', **params)
+ create_evaluate_ops(model_uri='gs://blah', **params)
if __name__ == '__main__':
diff --git a/tests/contrib/operators/test_qubole_operator.py b/tests/contrib/operators/test_qubole_operator.py
index bf61262bb3..c0894c0ba7 100644
--- a/tests/contrib/operators/test_qubole_operator.py
+++ b/tests/contrib/operators/test_qubole_operator.py
@@ -35,9 +35,9 @@
except ImportError:
mock = None
-DAG_ID="qubole_test_dag"
-TASK_ID="test_task"
-DEFAULT_CONN="qubole_default"
+DAG_ID = "qubole_test_dag"
+TASK_ID = "test_task"
+DEFAULT_CONN = "qubole_default"
TEMPLATE_CONN = "my_conn_id"
DEFAULT_DATE = datetime(2017, 1, 1)
@@ -60,7 +60,7 @@ def test_init_with_template_connection(self):
qubole_conn_id="{{
dag_run.conf['qubole_conn_id'] }}")
result = task.render_template('qubole_conn_id', "{{ qubole_conn_id }}",
- {'qubole_conn_id' : TEMPLATE_CONN})
+ {'qubole_conn_id': TEMPLATE_CONN})
self.assertEqual(task.task_id, TASK_ID)
self.assertEqual(result, TEMPLATE_CONN)
@@ -93,26 +93,22 @@ def test_hyphen_args_note_id(self):
dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
with dag:
- task = QuboleOperator(task_id=TASK_ID, command_type='sparkcmd',
- note_id="123", dag=dag)
-            self.assertEqual(task.get_hook().create_cmd_args({'run_id':'dummy'})[0],
-                             "--note-id=123")
+            task = QuboleOperator(task_id=TASK_ID, command_type='sparkcmd', note_id="123", dag=dag)
+
+            self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[0], "--note-id=123")
def test_position_args_parameters(self):
dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
with dag:
task = QuboleOperator(task_id=TASK_ID, command_type='pigcmd',
- parameters="key1=value1 key2=value2", dag=dag)
+ parameters="key1=value1 key2=value2", dag=dag)
-
self.assertEqual(task.get_hook().create_cmd_args({'run_id':'dummy'})[1],
- "key1=value1")
-
self.assertEqual(task.get_hook().create_cmd_args({'run_id':'dummy'})[2],
- "key2=value2")
+ self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[1], "key1=value1")
+ self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[2], "key2=value2")
- task = QuboleOperator(task_id=TASK_ID, command_type='hadoopcmd',
- sub_command="s3distcp --src s3n://airflow/source_hadoopcmd " +
- "--dest s3n://airflow/destination_hadoopcmd", dag=dag)
+ cmd = "s3distcp --src s3n://airflow/source_hadoopcmd --dest s3n://airflow/destination_hadoopcmd"
+ task = QuboleOperator(task_id=TASK_ID, command_type='hadoopcmd', dag=dag, sub_command=cmd)
self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[1],
"s3distcp")
@@ -124,5 +120,3 @@ def test_position_args_parameters(self):
"--dest")
self.assertEqual(task.get_hook().create_cmd_args({'run_id': 'dummy'})[5],
"s3n://airflow/destination_hadoopcmd")
-
-
diff --git a/tests/contrib/operators/test_sftp_operator.py
b/tests/contrib/operators/test_sftp_operator.py
index bf4525e311..7a450c0844 100644
--- a/tests/contrib/operators/test_sftp_operator.py
+++ b/tests/contrib/operators/test_sftp_operator.py
@@ -42,6 +42,7 @@ def reset(dag_id=TEST_DAG_ID):
session.commit()
session.close()
+
reset()
@@ -79,12 +80,12 @@ def test_pickle_file_transfer_put(self):
# put test file to remote
put_test_task = SFTPOperator(
- task_id="test_sftp",
- ssh_hook=self.hook,
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.PUT,
- dag=self.dag
+ task_id="test_sftp",
+ ssh_hook=self.hook,
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.PUT,
+ dag=self.dag
)
self.assertIsNotNone(put_test_task)
ti2 = TaskInstance(task=put_test_task,
execution_date=timezone.utcnow())
@@ -92,18 +93,18 @@ def test_pickle_file_transfer_put(self):
# check the remote file content
check_file_task = SSHOperator(
- task_id="test_check_file",
- ssh_hook=self.hook,
- command="cat {0}".format(self.test_remote_filepath),
- do_xcom_push=True,
- dag=self.dag
+ task_id="test_check_file",
+ ssh_hook=self.hook,
+ command="cat {0}".format(self.test_remote_filepath),
+ do_xcom_push=True,
+ dag=self.dag
)
self.assertIsNotNone(check_file_task)
ti3 = TaskInstance(task=check_file_task,
execution_date=timezone.utcnow())
ti3.run()
self.assertEqual(
- ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
- test_local_file_content)
+ ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
+ test_local_file_content)
def test_json_file_transfer_put(self):
configuration.conf.set("core", "enable_xcom_pickling", "False")
@@ -116,12 +117,12 @@ def test_json_file_transfer_put(self):
# put test file to remote
put_test_task = SFTPOperator(
- task_id="test_sftp",
- ssh_hook=self.hook,
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.PUT,
- dag=self.dag
+ task_id="test_sftp",
+ ssh_hook=self.hook,
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.PUT,
+ dag=self.dag
)
self.assertIsNotNone(put_test_task)
ti2 = TaskInstance(task=put_test_task,
execution_date=timezone.utcnow())
@@ -129,19 +130,18 @@ def test_json_file_transfer_put(self):
# check the remote file content
check_file_task = SSHOperator(
- task_id="test_check_file",
- ssh_hook=self.hook,
- command="cat {0}".format(self.test_remote_filepath),
- do_xcom_push=True,
- dag=self.dag
+ task_id="test_check_file",
+ ssh_hook=self.hook,
+ command="cat {0}".format(self.test_remote_filepath),
+ do_xcom_push=True,
+ dag=self.dag
)
self.assertIsNotNone(check_file_task)
ti3 = TaskInstance(task=check_file_task,
execution_date=timezone.utcnow())
ti3.run()
self.assertEqual(
- ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
- b64encode(test_local_file_content).decode('utf-8'))
-
+ ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
+ b64encode(test_local_file_content).decode('utf-8'))
def test_pickle_file_transfer_get(self):
configuration.conf.set("core", "enable_xcom_pickling", "True")
@@ -151,12 +151,12 @@ def test_pickle_file_transfer_get(self):
# create a test file remotely
create_file_task = SSHOperator(
- task_id="test_create_file",
- ssh_hook=self.hook,
- command="echo '{0}' > {1}".format(test_remote_file_content,
- self.test_remote_filepath),
- do_xcom_push=True,
- dag=self.dag
+ task_id="test_create_file",
+ ssh_hook=self.hook,
+ command="echo '{0}' > {1}".format(test_remote_file_content,
+ self.test_remote_filepath),
+ do_xcom_push=True,
+ dag=self.dag
)
self.assertIsNotNone(create_file_task)
ti1 = TaskInstance(task=create_file_task,
execution_date=timezone.utcnow())
@@ -164,12 +164,12 @@ def test_pickle_file_transfer_get(self):
# get remote file to local
get_test_task = SFTPOperator(
- task_id="test_sftp",
- ssh_hook=self.hook,
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.GET,
- dag=self.dag
+ task_id="test_sftp",
+ ssh_hook=self.hook,
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.GET,
+ dag=self.dag
)
self.assertIsNotNone(get_test_task)
ti2 = TaskInstance(task=get_test_task,
execution_date=timezone.utcnow())
@@ -189,12 +189,12 @@ def test_json_file_transfer_get(self):
# create a test file remotely
create_file_task = SSHOperator(
- task_id="test_create_file",
- ssh_hook=self.hook,
- command="echo '{0}' > {1}".format(test_remote_file_content,
- self.test_remote_filepath),
- do_xcom_push=True,
- dag=self.dag
+ task_id="test_create_file",
+ ssh_hook=self.hook,
+ command="echo '{0}' > {1}".format(test_remote_file_content,
+ self.test_remote_filepath),
+ do_xcom_push=True,
+ dag=self.dag
)
self.assertIsNotNone(create_file_task)
ti1 = TaskInstance(task=create_file_task,
execution_date=timezone.utcnow())
@@ -202,12 +202,12 @@ def test_json_file_transfer_get(self):
# get remote file to local
get_test_task = SFTPOperator(
- task_id="test_sftp",
- ssh_hook=self.hook,
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.GET,
- dag=self.dag
+ task_id="test_sftp",
+ ssh_hook=self.hook,
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.GET,
+ dag=self.dag
)
self.assertIsNotNone(get_test_task)
ti2 = TaskInstance(task=get_test_task,
execution_date=timezone.utcnow())
@@ -218,7 +218,7 @@ def test_json_file_transfer_get(self):
with open(self.test_local_filepath, 'r') as f:
content_received = f.read()
self.assertEqual(content_received.strip(),
- test_remote_file_content.encode('utf-8').decode('utf-8'))
+ test_remote_file_content.encode('utf-8').decode('utf-8'))
def test_arg_checking(self):
from airflow.exceptions import AirflowException
diff --git a/tests/contrib/operators/test_spark_sql_operator.py
b/tests/contrib/operators/test_spark_sql_operator.py
index b0c956931f..bbe6868dc7 100644
--- a/tests/contrib/operators/test_spark_sql_operator.py
+++ b/tests/contrib/operators/test_spark_sql_operator.py
@@ -28,7 +28,6 @@
class TestSparkSqlOperator(unittest.TestCase):
-
_config = {
'sql': 'SELECT 22',
'conn_id': 'spark_special_conn_id',
@@ -74,5 +73,6 @@ def test_execute(self):
self.assertEqual(self._config['num_executors'],
operator._num_executors)
self.assertEqual(self._config['yarn_queue'], operator._yarn_queue)
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/contrib/operators/test_sqoop_operator.py
b/tests/contrib/operators/test_sqoop_operator.py
index 5a235c487f..e3cf32fdae 100644
--- a/tests/contrib/operators/test_sqoop_operator.py
+++ b/tests/contrib/operators/test_sqoop_operator.py
@@ -111,7 +111,7 @@ def test_execute(self):
self.assertEqual(self._config['extra_export_options'],
operator.extra_export_options)
# the following are meant to be more of examples
- sqoop_import_op = SqoopOperator(
+ SqoopOperator(
task_id='sqoop_import_using_table',
cmd_type='import',
conn_id='sqoop_default',
@@ -125,12 +125,13 @@ def test_execute(self):
dag=self.dag
)
- sqoop_import_op_qry = SqoopOperator(
+ SqoopOperator(
task_id='sqoop_import_using_query',
cmd_type='import',
conn_id='sqoop_default',
query='select name, age from company where $CONDITIONS',
- split_by='age', # the mappers will pass in values to the $CONDITIONS based on the field you select to split by
+ split_by='age',
+ # the mappers will pass in values to the $CONDITIONS based on the field you select to split by
verbose=True,
num_mappers=None,
hcatalog_database='default',
@@ -140,7 +141,7 @@ def test_execute(self):
dag=self.dag
)
- sqoop_import_op_with_partition = SqoopOperator(
+ SqoopOperator(
task_id='sqoop_import_with_partition',
cmd_type='import',
conn_id='sqoop_default',
@@ -157,7 +158,7 @@ def test_execute(self):
dag=self.dag
)
- sqoop_export_op_name = SqoopOperator(
+ SqoopOperator(
task_id='sqoop_export_tablename',
cmd_type='export',
conn_id='sqoop_default',
@@ -170,7 +171,7 @@ def test_execute(self):
dag=self.dag
)
- sqoop_export_op_path = SqoopOperator(
+ SqoopOperator(
task_id='sqoop_export_tablepath',
cmd_type='export',
conn_id='sqoop_default',
diff --git a/tests/contrib/operators/test_ssh_operator.py
b/tests/contrib/operators/test_ssh_operator.py
index 1a2c788596..00c4be679b 100644
--- a/tests/contrib/operators/test_ssh_operator.py
+++ b/tests/contrib/operators/test_ssh_operator.py
@@ -40,6 +40,7 @@ def reset(dag_id=TEST_DAG_ID):
session.commit()
session.close()
+
reset()
@@ -79,17 +80,17 @@ def test_hook_created_correctly(self):
def test_json_command_execution(self):
configuration.conf.set("core", "enable_xcom_pickling", "False")
task = SSHOperator(
- task_id="test",
- ssh_hook=self.hook,
- command="echo -n airflow",
- do_xcom_push=True,
- dag=self.dag,
+ task_id="test",
+ ssh_hook=self.hook,
+ command="echo -n airflow",
+ do_xcom_push=True,
+ dag=self.dag,
)
self.assertIsNotNone(task)
ti = TaskInstance(
- task=task, execution_date=timezone.utcnow())
+ task=task, execution_date=timezone.utcnow())
ti.run()
self.assertIsNotNone(ti.duration)
self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'),
@@ -98,17 +99,17 @@ def test_json_command_execution(self):
def test_pickle_command_execution(self):
configuration.conf.set("core", "enable_xcom_pickling", "True")
task = SSHOperator(
- task_id="test",
- ssh_hook=self.hook,
- command="echo -n airflow",
- do_xcom_push=True,
- dag=self.dag,
+ task_id="test",
+ ssh_hook=self.hook,
+ command="echo -n airflow",
+ do_xcom_push=True,
+ dag=self.dag,
)
self.assertIsNotNone(task)
ti = TaskInstance(
- task=task, execution_date=timezone.utcnow())
+ task=task, execution_date=timezone.utcnow())
ti.run()
self.assertIsNotNone(ti.duration)
self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'),
b'airflow')
diff --git a/tests/contrib/operators/test_vertica_to_mysql.py
b/tests/contrib/operators/test_vertica_to_mysql.py
index 615f111f84..c76c4d5dcc 100644
--- a/tests/contrib/operators/test_vertica_to_mysql.py
+++ b/tests/contrib/operators/test_vertica_to_mysql.py
@@ -30,14 +30,14 @@ def mock_get_conn():
commit_mock = mock.MagicMock(
)
cursor_mock = mock.MagicMock(
- execute = [],
- fetchall = [['1', '2', '3']],
- description = ['a', 'b', 'c'],
- iterate = [['1', '2', '3']],
+ execute=[],
+ fetchall=[['1', '2', '3']],
+ description=['a', 'b', 'c'],
+ iterate=[['1', '2', '3']],
)
conn_mock = mock.MagicMock(
- commit = commit_mock,
- cursor = cursor_mock,
+ commit=commit_mock,
+ cursor=cursor_mock,
)
return conn_mock
diff --git a/tests/contrib/sensors/test_emr_job_flow_sensor.py
b/tests/contrib/sensors/test_emr_job_flow_sensor.py
index 606cd84c32..5b33cb5bb4 100644
--- a/tests/contrib/sensors/test_emr_job_flow_sensor.py
+++ b/tests/contrib/sensors/test_emr_job_flow_sensor.py
@@ -42,7 +42,8 @@
'Status': {
'State': 'STARTING',
'StateChangeReason': {},
- 'Timeline': {'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())}
+ 'Timeline': {
+ 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())}
},
'Tags': [
{'Key': 'app', 'Value': 'analytics'},
@@ -74,7 +75,8 @@
'Status': {
'State': 'TERMINATED',
'StateChangeReason': {},
- 'Timeline': {'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())}
+ 'Timeline': {
+ 'CreationDateTime': datetime.datetime(2016, 6, 27, 21, 5, 2, 348000, tzinfo=tzlocal())}
},
'Tags': [
{'Key': 'app', 'Value': 'analytics'},
@@ -107,10 +109,8 @@ def setUp(self):
# Mock out the emr_client creator
self.boto3_session_mock = MagicMock(return_value=mock_emr_session)
-
def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_terminal_state(self):
with patch('boto3.session.Session', self.boto3_session_mock):
-
operator = EmrJobFlowSensor(
task_id='test_task',
poke_interval=2,
diff --git a/tests/contrib/sensors/test_jira_sensor_test.py
b/tests/contrib/sensors/test_jira_sensor_test.py
index 561b4babd1..32aa235851 100644
--- a/tests/contrib/sensors/test_jira_sensor_test.py
+++ b/tests/contrib/sensors/test_jira_sensor_test.py
@@ -67,13 +67,13 @@ def setUp(self):
def test_issue_label_set(self, jira_mock):
jira_mock.return_value.issue.return_value = minimal_test_ticket
- ticket_label_sensor = JiraTicketSensor(task_id='search-ticket-test',
- ticket_id='TEST-1226',
- field_checker_func=
- TestJiraSensor.field_checker_func,
- timeout=518400,
- poke_interval=10,
- dag=self.dag)
+ ticket_label_sensor = JiraTicketSensor(
+ task_id='search-ticket-test',
+ ticket_id='TEST-1226',
+ field_checker_func=TestJiraSensor.field_checker_func,
+ timeout=518400,
+ poke_interval=10,
+ dag=self.dag)
ticket_label_sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
diff --git a/tests/contrib/utils/__init__.py b/tests/contrib/utils/__init__.py
index 331c28ef9a..b7f8352944 100644
--- a/tests/contrib/utils/__init__.py
+++ b/tests/contrib/utils/__init__.py
@@ -17,4 +17,3 @@
# specific language governing permissions and limitations
# under the License.
#
-
diff --git a/tests/core.py b/tests/core.py
index 918e9b4d49..91e2b069f5 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -917,11 +917,11 @@ def test_task_fail_duration(self):
session = settings.Session()
try:
p.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
ignore_ti_state=True)
- except:
+ except Exception:
pass
try:
f.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
ignore_ti_state=True)
- except:
+ except Exception:
pass
p_fails = session.query(models.TaskFail).filter_by(
task_id='pass_sleepy',
@@ -931,10 +931,9 @@ def test_task_fail_duration(self):
task_id='fail_sleepy',
dag_id=self.dag.dag_id,
execution_date=DEFAULT_DATE).all()
- print(f_fails)
+
self.assertEqual(0, len(p_fails))
self.assertEqual(1, len(f_fails))
- # C
self.assertGreaterEqual(sum([f.duration for f in f_fails]), 3)
def test_dag_stats(self):
@@ -1624,7 +1623,7 @@ def _wait_pidfile(self, pidfile):
try:
with open(pidfile) as f:
return int(f.read())
- except:
+ except Exception:
sleep(1)
def test_cli_webserver_foreground(self):
@@ -1734,7 +1733,7 @@ def test_csrf_acceptance(self):
def test_xss(self):
try:
self.app.get("/admin/airflow/tree?dag_id=<script>alert(123456)</script>")
- except:
+ except Exception:
# exception is expected here since dag doesnt exist
pass
response = self.app.get("/admin/log", follow_redirects=True)
@@ -2185,7 +2184,7 @@ def setUp(self):
configuration.conf.set("webserver", "auth_backend",
"airflow.contrib.auth.backends.ldap_auth")
try:
configuration.conf.add_section("ldap")
- except:
+ except Exception:
pass
configuration.conf.set("ldap", "uri", "ldap://openldap:389")
configuration.conf.set("ldap", "user_filter", "objectClass=*")
@@ -2272,7 +2271,7 @@ def setUp(self):
configuration.conf.set("webserver", "auth_backend",
"airflow.contrib.auth.backends.ldap_auth")
try:
configuration.conf.add_section("ldap")
- except:
+ except Exception:
pass
configuration.conf.set("ldap", "uri", "ldap://openldap:389")
configuration.conf.set("ldap", "user_filter", "objectClass=*")
@@ -2611,6 +2610,7 @@ def test_get_ha_client(self, mock_get_connections):
client = HDFSHook().get_conn()
self.assertIsInstance(client, snakebite.client.HAClient)
+
send_email_test = mock.Mock()
diff --git a/tests/dags/test_cli_triggered_dags.py
b/tests/dags/test_cli_triggered_dags.py
index f625e18475..9f53ca4c3a 100644
--- a/tests/dags/test_cli_triggered_dags.py
+++ b/tests/dags/test_cli_triggered_dags.py
@@ -42,14 +42,15 @@ def success(ti=None, *args, **kwargs):
# DAG tests that tasks ignore all dependencies
-dag1 = DAG(dag_id='test_run_ignores_all_dependencies',
default_args=dict(depends_on_past=True, **default_args))
+dag1 = DAG(dag_id='test_run_ignores_all_dependencies',
+ default_args=dict(depends_on_past=True, **default_args))
dag1_task1 = PythonOperator(
task_id='test_run_dependency_task',
python_callable=fail,
- dag=dag1,)
+ dag=dag1)
dag1_task2 = PythonOperator(
task_id='test_run_dependent_task',
python_callable=success,
provide_context=True,
- dag=dag1,)
+ dag=dag1)
dag1_task1.set_downstream(dag1_task2)
diff --git a/tests/dags/test_example_bash_operator.py
b/tests/dags/test_example_bash_operator.py
index f9bd6c7863..a87db8dd7c 100644
--- a/tests/dags/test_example_bash_operator.py
+++ b/tests/dags/test_example_bash_operator.py
@@ -46,7 +46,7 @@
for i in range(3):
i = str(i)
task = BashOperator(
- task_id='runme_'+i,
+ task_id='runme_' + i,
bash_command='echo "{{ task_instance_key_str }}" && sleep 1',
dag=dag)
task.set_downstream(run_this)
diff --git a/tests/dags/test_issue_1225.py b/tests/dags/test_issue_1225.py
index 8009f48904..0450cf470f 100644
--- a/tests/dags/test_issue_1225.py
+++ b/tests/dags/test_issue_1225.py
@@ -37,9 +37,11 @@
start_date=DEFAULT_DATE,
owner='airflow')
+
def fail():
raise ValueError('Expected failure.')
+
def delayed_fail():
"""
Delayed failure to make sure that processes are running before the error
@@ -50,6 +52,7 @@ def delayed_fail():
time.sleep(5)
raise ValueError('Expected failure.')
+
# DAG tests backfill with pooled tasks
# Previously backfill would queue the task but never run it
dag1 = DAG(dag_id='test_backfill_pooled_task_dag', default_args=default_args)
diff --git a/tests/dags/test_retry_handling_job.py
b/tests/dags/test_retry_handling_job.py
index 39c29a8d61..d8e314dbab 100644
--- a/tests/dags/test_retry_handling_job.py
+++ b/tests/dags/test_retry_handling_job.py
@@ -24,7 +24,7 @@
default_args = {
'owner': 'airflow',
'depends_on_past': False,
- 'start_date': datetime(2016,10,5,19),
+ 'start_date': datetime(2016, 10, 5, 19),
'email': ['[email protected]'],
'email_on_failure': False,
'email_on_retry': False,
@@ -38,4 +38,3 @@
task_id='test_retry_handling_op',
bash_command='exit 1',
dag=dag)
-
diff --git a/tests/executors/test_base_executor.py
b/tests/executors/test_base_executor.py
index f640a75e01..d3032cd640 100644
--- a/tests/executors/test_base_executor.py
+++ b/tests/executors/test_base_executor.py
@@ -42,4 +42,3 @@ def test_get_event_buffer(self):
self.assertEqual(len(executor.get_event_buffer(("my_dag1",))), 1)
self.assertEqual(len(executor.get_event_buffer()), 2)
self.assertEqual(len(executor.event_buffer), 0)
-
diff --git a/tests/executors/test_celery_executor.py
b/tests/executors/test_celery_executor.py
index f1b6a429fa..380201d30a 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -48,7 +48,7 @@ def test_celery_integration(self):
# errors are propagated for some reason
try:
executor.execute_async(key='fail', command=fail_command)
- except:
+ except Exception:
pass
executor.running['success'] = True
executor.running['fail'] = True
diff --git a/tests/executors/test_executor.py b/tests/executors/test_executor.py
index 23dd8d691c..aab66644b8 100644
--- a/tests/executors/test_executor.py
+++ b/tests/executors/test_executor.py
@@ -26,6 +26,7 @@ class TestExecutor(BaseExecutor):
"""
TestExecutor is used for unit testing purposes.
"""
+
def __init__(self, do_update=False, *args, **kwargs):
self.do_update = do_update
self._running = []
@@ -58,4 +59,3 @@ def terminate(self):
def end(self):
self.sync()
-
diff --git a/tests/executors/test_local_executor.py
b/tests/executors/test_local_executor.py
index 59cb09c74e..2a29ee2cd5 100644
--- a/tests/executors/test_local_executor.py
+++ b/tests/executors/test_local_executor.py
@@ -44,7 +44,7 @@ def execution_parallelism(self, parallelism=0):
# errors are propagated for some reason
try:
executor.execute_async(key='fail', command=fail_command)
- except:
+ except Exception:
pass
executor.running['fail'] = True
diff --git a/tests/hooks/__init__.py b/tests/hooks/__init__.py
index 4067cc78ee..114d189da1 100644
--- a/tests/hooks/__init__.py
+++ b/tests/hooks/__init__.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/hooks/test_docker_hook.py b/tests/hooks/test_docker_hook.py
index b8a0132e50..dd7ed4d44d 100644
--- a/tests/hooks/test_docker_hook.py
+++ b/tests/hooks/test_docker_hook.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -101,7 +101,7 @@ def test_get_conn_with_standard_config(self, _):
)
client = hook.get_conn()
self.assertIsNotNone(client)
- except:
+ except Exception:
self.fail('Could not get connection from Airflow')
def test_get_conn_with_extra_config(self, _):
@@ -113,7 +113,7 @@ def test_get_conn_with_extra_config(self, _):
)
client = hook.get_conn()
self.assertIsNotNone(client)
- except:
+ except Exception:
self.fail('Could not get connection from Airflow')
def test_conn_with_standard_config_passes_parameters(self, _):
@@ -157,7 +157,7 @@ def test_conn_with_broken_config_missing_username_fails(self, _):
)
)
with self.assertRaises(AirflowException):
- hook = DockerHook(
+ DockerHook(
docker_conn_id='docker_without_username',
base_url='unix://var/run/docker.sock',
version='auto'
@@ -173,7 +173,7 @@ def test_conn_with_broken_config_missing_host_fails(self, _):
)
)
with self.assertRaises(AirflowException):
- hook = DockerHook(
+ DockerHook(
docker_conn_id='docker_without_host',
base_url='unix://var/run/docker.sock',
version='auto'
diff --git a/tests/hooks/test_postgres_hook.py
b/tests/hooks/test_postgres_hook.py
index 3e71f60f5b..f937c26782 100644
--- a/tests/hooks/test_postgres_hook.py
+++ b/tests/hooks/test_postgres_hook.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/hooks/test_slack_hook.py b/tests/hooks/test_slack_hook.py
index 46f8946411..73193dff8c 100644
--- a/tests/hooks/test_slack_hook.py
+++ b/tests/hooks/test_slack_hook.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/jobs.py b/tests/jobs.py
index 9dcd15fbe6..9b265724b6 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -1294,7 +1294,7 @@ def test_process_executor_events(self):
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE)
task1 = DummyOperator(dag=dag, task_id=task_id_1)
- task2 = DummyOperator(dag=dag2, task_id=task_id_1)
+ DummyOperator(dag=dag2, task_id=task_id_1)
dagbag1 = self._make_simple_dag_bag([dag])
dagbag2 = self._make_simple_dag_bag([dag2])
@@ -1470,7 +1470,7 @@ def test_find_executable_task_instances_pool(self):
TI(task2, dr1.execution_date),
TI(task1, dr2.execution_date),
TI(task2, dr2.execution_date)
- ])
+ ])
for ti in tis:
ti.state = State.SCHEDULED
session.merge(ti)
@@ -1849,8 +1849,12 @@ def test_execute_task_instances(self):
session.commit()
self.assertEqual(State.RUNNING, dr1.state)
- self.assertEqual(2, DAG.get_num_task_instances(dag_id, dag.task_ids,
- states=[State.RUNNING], session=session))
+ self.assertEqual(
+ 2,
+ DAG.get_num_task_instances(
+ dag_id, dag.task_ids, states=[State.RUNNING], session=session
+ )
+ )
# create second dag run
dr2 = scheduler.create_dag_run(dag)
@@ -1874,8 +1878,12 @@ def test_execute_task_instances(self):
ti2.refresh_from_db()
ti3.refresh_from_db()
ti4.refresh_from_db()
- self.assertEqual(3, DAG.get_num_task_instances(dag_id, dag.task_ids,
- states=[State.RUNNING, State.QUEUED], session=session))
+ self.assertEqual(
+ 3,
+ DAG.get_num_task_instances(
+ dag_id, dag.task_ids, states=[State.RUNNING, State.QUEUED], session=session
+ )
+ )
self.assertEqual(State.RUNNING, ti1.state)
self.assertEqual(State.RUNNING, ti2.state)
six.assertCountEqual(self, [State.QUEUED, State.SCHEDULED],
[ti3.state, ti4.state])
@@ -1922,43 +1930,26 @@ def test_execute_task_instances_limit(self):
@unittest.skipUnless("INTEGRATION" in os.environ,
"The test is flaky with nondeterministic result")
def test_change_state_for_tis_without_dagrun(self):
- dag1 = DAG(
- dag_id='test_change_state_for_tis_without_dagrun',
- start_date=DEFAULT_DATE)
+ dag1 = DAG(dag_id='test_change_state_for_tis_without_dagrun', start_date=DEFAULT_DATE)
- DummyOperator(
- task_id='dummy',
- dag=dag1,
- owner='airflow')
- DummyOperator(
- task_id='dummy_b',
- dag=dag1,
- owner='airflow')
+ DummyOperator(task_id='dummy', dag=dag1, owner='airflow')
- dag2 = DAG(
- dag_id='test_change_state_for_tis_without_dagrun_dont_change',
- start_date=DEFAULT_DATE)
+ DummyOperator(task_id='dummy_b', dag=dag1, owner='airflow')
- DummyOperator(
- task_id='dummy',
- dag=dag2,
- owner='airflow')
+ dag2 = DAG(dag_id='test_change_state_for_tis_without_dagrun_dont_change', start_date=DEFAULT_DATE)
- dag3 = DAG(
- dag_id='test_change_state_for_tis_without_dagrun_no_dagrun',
- start_date=DEFAULT_DATE)
+ DummyOperator(task_id='dummy', dag=dag2, owner='airflow')
- DummyOperator(
- task_id='dummy',
- dag=dag3,
- owner='airflow')
+ dag3 = DAG(dag_id='test_change_state_for_tis_without_dagrun_no_dagrun', start_date=DEFAULT_DATE)
+
+ DummyOperator(task_id='dummy', dag=dag3, owner='airflow')
session = settings.Session()
dr1 = dag1.create_dagrun(run_id=DagRun.ID_PREFIX,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session)
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session)
dr2 = dag2.create_dagrun(run_id=DagRun.ID_PREFIX,
state=State.RUNNING,
@@ -2344,7 +2335,7 @@ def test_scheduler_do_not_schedule_removed_task(self):
dag = DAG(
dag_id='test_scheduler_do_not_schedule_removed_task',
start_date=DEFAULT_DATE)
- dag_task1 = DummyOperator(
+ DummyOperator(
task_id='dummy',
dag=dag,
owner='airflow')
@@ -2374,7 +2365,7 @@ def test_scheduler_do_not_schedule_too_early(self):
dag = DAG(
dag_id='test_scheduler_do_not_schedule_too_early',
start_date=timezone.datetime(2200, 1, 1))
- dag_task1 = DummyOperator(
+ DummyOperator(
task_id='dummy',
dag=dag,
owner='airflow')
@@ -2400,7 +2391,7 @@ def test_scheduler_do_not_run_finished(self):
dag = DAG(
dag_id='test_scheduler_do_not_run_finished',
start_date=DEFAULT_DATE)
- dag_task1 = DummyOperator(
+ DummyOperator(
task_id='dummy',
dag=dag,
owner='airflow')
@@ -2436,7 +2427,7 @@ def test_scheduler_add_new_task(self):
dag_id='test_scheduler_add_new_task',
start_date=DEFAULT_DATE)
- dag_task1 = DummyOperator(
+ DummyOperator(
task_id='dummy',
dag=dag,
owner='airflow')
@@ -2456,7 +2447,7 @@ def test_scheduler_add_new_task(self):
tis = dr.get_task_instances()
self.assertEquals(len(tis), 1)
- dag_task2 = DummyOperator(
+ DummyOperator(
task_id='dummy2',
dag=dag,
owner='airflow')
@@ -2476,7 +2467,7 @@ def test_scheduler_verify_max_active_runs(self):
start_date=DEFAULT_DATE)
dag.max_active_runs = 1
- dag_task1 = DummyOperator(
+ DummyOperator(
task_id='dummy',
dag=dag,
owner='airflow')
@@ -2505,7 +2496,7 @@ def test_scheduler_fail_dagrun_timeout(self):
start_date=DEFAULT_DATE)
dag.dagrun_timeout = datetime.timedelta(seconds=60)
- dag_task1 = DummyOperator(
+ DummyOperator(
task_id='dummy',
dag=dag,
owner='airflow')
@@ -2532,8 +2523,11 @@ def test_scheduler_fail_dagrun_timeout(self):
def test_scheduler_verify_max_active_runs_and_dagrun_timeout(self):
"""
- Test if a a dagrun will not be scheduled if max_dag_runs has been
reached and dagrun_timeout is not reached
- Test if a a dagrun will be scheduled if max_dag_runs has been reached
but dagrun_timeout is also reached
+ Test if a a dagrun will not be scheduled if max_dag_runs
+ has been reached and dagrun_timeout is not reached
+
+ Test if a a dagrun will be scheduled if max_dag_runs has
+ been reached but dagrun_timeout is also reached
"""
dag = DAG(
dag_id='test_scheduler_verify_max_active_runs_and_dagrun_timeout',
@@ -2541,7 +2535,7 @@ def test_scheduler_verify_max_active_runs_and_dagrun_timeout(self):
dag.max_active_runs = 1
dag.dagrun_timeout = datetime.timedelta(seconds=60)
- dag_task1 = DummyOperator(
+ DummyOperator(
task_id='dummy',
dag=dag,
owner='airflow')
@@ -2677,7 +2671,7 @@ def test_scheduler_auto_align(self):
start_date=timezone.datetime(2016, 1, 1, 10, 10, 0),
schedule_interval="4 5 * * *"
)
- dag_task1 = DummyOperator(
+ DummyOperator(
task_id='dummy',
dag=dag,
owner='airflow')
@@ -2699,7 +2693,7 @@ def test_scheduler_auto_align(self):
start_date=timezone.datetime(2016, 1, 1, 10, 10, 0),
schedule_interval="10 10 * * *"
)
- dag_task1 = DummyOperator(
+ DummyOperator(
task_id='dummy',
dag=dag,
owner='airflow')
@@ -2730,7 +2724,7 @@ def test_scheduler_reschedule(self):
dag = DAG(
dag_id='test_scheduler_reschedule',
start_date=DEFAULT_DATE)
- dag_task1 = DummyOperator(
+ DummyOperator(
task_id='dummy',
dag=dag,
owner='airflow')
@@ -2775,7 +2769,8 @@ def test_scheduler_sla_miss_callback(self):
# Mock the callback function so we can verify that it was not called
sla_callback = MagicMock()
- # Create dag with a start of 2 days ago, but an sla of 1 day ago so we'll already have an sla_miss on the books
+ # Create dag with a start of 2 days ago, but an sla of 1 day
+ # ago so we'll already have an sla_miss on the books
test_start_date = days_ago(2)
dag = DAG(dag_id='test_sla_miss',
sla_miss_callback=sla_callback,
@@ -2982,8 +2977,8 @@ def test_retry_handling_job(self):
scheduler.run()
session = settings.Session()
- ti = session.query(TI).filter(TI.dag_id==dag.dag_id,
- TI.task_id==dag_task1.task_id).first()
+ ti = session.query(TI).filter(TI.dag_id == dag.dag_id,
+ TI.task_id == dag_task1.task_id).first()
# make sure the counter has increased
self.assertEqual(ti.try_number, 2)
@@ -3043,7 +3038,8 @@ def test_dag_get_active_runs(self):
"""
now = timezone.utcnow()
- six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace(minute=0, second=0, microsecond=0)
+ six_hours_ago_to_the_hour = \
+ (now - datetime.timedelta(hours=6)).replace(minute=0, second=0, microsecond=0)
START_DATE = six_hours_ago_to_the_hour
DAG_NAME1 = 'get_active_runs_test'
@@ -3086,7 +3082,7 @@ def test_dag_get_active_runs(self):
try:
running_date = running_dates[0]
- except:
+ except Exception as _:
running_date = 'Except'
self.assertEqual(execution_date, running_date, 'Running Date must match Execution Date')
@@ -3169,7 +3165,6 @@ def setup_dag(dag_id, schedule_interval, start_date,
catchup):
dr = scheduler.create_dag_run(dag4)
self.assertIsNotNone(dr)
-
def test_add_unparseable_file_before_sched_start_creates_import_error(self):
try:
dags_folder = mkdtemp()
diff --git a/tests/models.py b/tests/models.py
index 1479b63ec4..f2d36a263b 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -636,10 +636,10 @@ def test_dagstats_crud(self):
default_args={'owner': 'owner1'})
with dag:
- op1 = DummyOperator(task_id='A')
+ DummyOperator(task_id='A')
now = timezone.utcnow()
- dr = dag.create_dagrun(
+ dag.create_dagrun(
run_id='manual__' + now.isoformat(),
execution_date=now,
start_date=now,
@@ -914,8 +914,8 @@ def test_dagrun_no_deadlock_with_depends_on_past(self):
dag = DAG('test_dagrun_no_deadlock',
start_date=DEFAULT_DATE)
with dag:
- op1 = DummyOperator(task_id='dop', depends_on_past=True)
- op2 = DummyOperator(task_id='tc', task_concurrency=1)
+ DummyOperator(task_id='dop', depends_on_past=True)
+ DummyOperator(task_id='tc', task_concurrency=1)
dag.clear()
dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_1',
@@ -927,9 +927,9 @@ def test_dagrun_no_deadlock_with_depends_on_past(self):
execution_date=DEFAULT_DATE +
datetime.timedelta(days=1),
start_date=DEFAULT_DATE +
datetime.timedelta(days=1))
ti1_op1 = dr.get_task_instance(task_id='dop')
- ti2_op1 = dr2.get_task_instance(task_id='dop')
+ dr2.get_task_instance(task_id='dop')
ti2_op1 = dr.get_task_instance(task_id='tc')
- ti2_op2 = dr.get_task_instance(task_id='tc')
+ dr.get_task_instance(task_id='tc')
ti1_op1.set_state(state=State.RUNNING, session=session)
dr.update_state()
dr2.update_state()
@@ -1130,7 +1130,7 @@ def test_get_task_instance_on_empty_dagrun(self):
dag_id='test_get_task_instance_on_empty_dagrun',
start_date=timezone.datetime(2017, 1, 1)
)
- dag_task1 = ShortCircuitOperator(
+ ShortCircuitOperator(
task_id='test_short_circuit_false',
dag=dag,
python_callable=lambda: False)
@@ -1160,10 +1160,8 @@ def test_get_latest_runs(self):
dag = DAG(
dag_id='test_latest_runs_1',
start_date=DEFAULT_DATE)
- dag_1_run_1 = self.create_dag_run(dag,
- execution_date=timezone.datetime(2015, 1, 1))
- dag_1_run_2 = self.create_dag_run(dag,
- execution_date=timezone.datetime(2015, 1, 2))
+ self.create_dag_run(dag, execution_date=timezone.datetime(2015, 1, 1))
+ self.create_dag_run(dag, execution_date=timezone.datetime(2015, 1, 2))
dagruns = models.DagRun.get_latest_runs(session)
session.close()
for dagrun in dagruns:
@@ -1315,9 +1313,9 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
super(TestDagBag, self).process_file(filepath,
only_if_updated, safe_mode)
dagbag = TestDagBag(include_examples=True)
- processed_files = dagbag.process_file_calls
+ dagbag.process_file_calls
- # Should not call process_file agani, since it's already loaded during init.
+ # Should not call process_file again, since it's already loaded during init.
self.assertEqual(1, dagbag.process_file_calls)
self.assertIsNotNone(dagbag.get_dag(dag_id))
self.assertEqual(1, dagbag.process_file_calls)
diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py
index 4924560599..114d189da1 100644
--- a/tests/operators/__init__.py
+++ b/tests/operators/__init__.py
@@ -16,4 +16,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
diff --git a/tests/operators/docker_operator.py
b/tests/operators/docker_operator.py
index a216d9bd50..ea90c53c28 100644
--- a/tests/operators/docker_operator.py
+++ b/tests/operators/docker_operator.py
@@ -237,5 +237,6 @@ def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock,
'Image was not pulled using operator client'
)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/operators/latest_only_operator.py
b/tests/operators/latest_only_operator.py
index ffce39f569..7a4efa11b5 100644
--- a/tests/operators/latest_only_operator.py
+++ b/tests/operators/latest_only_operator.py
@@ -91,7 +91,7 @@ def test_skipping(self):
self.assertEqual({
timezone.datetime(2016, 1, 1): 'success',
timezone.datetime(2016, 1, 1, 12): 'success',
- timezone.datetime(2016, 1, 2): 'success', },
+ timezone.datetime(2016, 1, 2): 'success'},
exec_date_to_latest_state)
downstream_instances = get_task_instances('downstream')
@@ -100,7 +100,7 @@ def test_skipping(self):
self.assertEqual({
timezone.datetime(2016, 1, 1): 'skipped',
timezone.datetime(2016, 1, 1, 12): 'skipped',
- timezone.datetime(2016, 1, 2): 'success',},
+ timezone.datetime(2016, 1, 2): 'success'},
exec_date_to_downstream_state)
downstream_instances = get_task_instances('downstream_2')
@@ -109,7 +109,7 @@ def test_skipping(self):
self.assertEqual({
timezone.datetime(2016, 1, 1): 'skipped',
timezone.datetime(2016, 1, 1, 12): 'skipped',
- timezone.datetime(2016, 1, 2): 'success',},
+ timezone.datetime(2016, 1, 2): 'success'},
exec_date_to_downstream_state)
def test_skipping_dagrun(self):
@@ -126,21 +126,21 @@ def test_skipping_dagrun(self):
downstream_task.set_upstream(latest_task)
downstream_task2.set_upstream(downstream_task)
- dr1 = self.dag.create_dagrun(
+ self.dag.create_dagrun(
run_id="manual__1",
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
state=State.RUNNING
)
- dr2 = self.dag.create_dagrun(
+ self.dag.create_dagrun(
run_id="manual__2",
start_date=timezone.utcnow(),
execution_date=timezone.datetime(2016, 1, 1, 12),
state=State.RUNNING
)
- dr2 = self.dag.create_dagrun(
+ self.dag.create_dagrun(
run_id="manual__3",
start_date=timezone.utcnow(),
execution_date=END_DATE,
@@ -157,7 +157,7 @@ def test_skipping_dagrun(self):
self.assertEqual({
timezone.datetime(2016, 1, 1): 'success',
timezone.datetime(2016, 1, 1, 12): 'success',
- timezone.datetime(2016, 1, 2): 'success', },
+ timezone.datetime(2016, 1, 2): 'success'},
exec_date_to_latest_state)
downstream_instances = get_task_instances('downstream')
@@ -166,7 +166,7 @@ def test_skipping_dagrun(self):
self.assertEqual({
timezone.datetime(2016, 1, 1): 'skipped',
timezone.datetime(2016, 1, 1, 12): 'skipped',
- timezone.datetime(2016, 1, 2): 'success',},
+ timezone.datetime(2016, 1, 2): 'success'},
exec_date_to_downstream_state)
downstream_instances = get_task_instances('downstream_2')
@@ -175,5 +175,5 @@ def test_skipping_dagrun(self):
self.assertEqual({
timezone.datetime(2016, 1, 1): 'skipped',
timezone.datetime(2016, 1, 1, 12): 'skipped',
- timezone.datetime(2016, 1, 2): 'success',},
+ timezone.datetime(2016, 1, 2): 'success'},
exec_date_to_downstream_state)
diff --git a/tests/operators/python_operator.py
b/tests/operators/python_operator.py
index 735a4d78c6..afc2a1383a 100644
--- a/tests/operators/python_operator.py
+++ b/tests/operators/python_operator.py
@@ -275,8 +275,8 @@ def test_without_dag_run(self):
value = False
dag = DAG('shortcircuit_operator_test_without_dag_run',
default_args={
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE
},
schedule_interval=INTERVAL)
short_op = ShortCircuitOperator(task_id='make_choice',
@@ -330,8 +330,8 @@ def test_with_dag_run(self):
value = False
dag = DAG('shortcircuit_operator_test_with_dag_run',
default_args={
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE
},
schedule_interval=INTERVAL)
short_op = ShortCircuitOperator(task_id='make_choice',
diff --git a/tests/operators/s3_to_hive_operator.py
b/tests/operators/s3_to_hive_operator.py
index 21ef29e25b..bafa3eb191 100644
--- a/tests/operators/s3_to_hive_operator.py
+++ b/tests/operators/s3_to_hive_operator.py
@@ -18,6 +18,7 @@
# under the License.
import unittest
+
try:
from unittest import mock
except ImportError:
@@ -31,7 +32,7 @@
from collections import OrderedDict
from airflow.exceptions import AirflowException
from tempfile import NamedTemporaryFile, mkdtemp
-import gzip
+from gzip import GzipFile
import bz2
import shutil
import filecmp
@@ -85,39 +86,31 @@ def setUp(self):
self._set_fn(f_txt_h.name, '.txt', True)
f_txt_h.writelines([header, line1, line2])
fn_gz = self._get_fn('.txt', True) + ".gz"
- with gzip.GzipFile(filename=fn_gz,
- mode="wb") as f_gz_h:
+ with GzipFile(filename=fn_gz, mode="wb") as f_gz_h:
self._set_fn(fn_gz, '.gz', True)
f_gz_h.writelines([header, line1, line2])
fn_gz_upper = self._get_fn('.txt', True) + ".GZ"
- with gzip.GzipFile(filename=fn_gz_upper,
- mode="wb") as f_gz_upper_h:
+ with GzipFile(filename=fn_gz_upper, mode="wb") as f_gz_upper_h:
self._set_fn(fn_gz_upper, '.GZ', True)
f_gz_upper_h.writelines([header, line1, line2])
fn_bz2 = self._get_fn('.txt', True) + '.bz2'
- with bz2.BZ2File(filename=fn_bz2,
- mode="wb") as f_bz2_h:
+ with bz2.BZ2File(filename=fn_bz2, mode="wb") as f_bz2_h:
self._set_fn(fn_bz2, '.bz2', True)
f_bz2_h.writelines([header, line1, line2])
# create sample txt, bz and bz2 without header
- with NamedTemporaryFile(mode='wb+',
- dir=self.tmp_dir,
- delete=False) as f_txt_nh:
+ with NamedTemporaryFile(mode='wb+', dir=self.tmp_dir, delete=False) as f_txt_nh:
self._set_fn(f_txt_nh.name, '.txt', False)
f_txt_nh.writelines([line1, line2])
fn_gz = self._get_fn('.txt', False) + ".gz"
- with gzip.GzipFile(filename=fn_gz,
- mode="wb") as f_gz_nh:
+ with GzipFile(filename=fn_gz, mode="wb") as f_gz_nh:
self._set_fn(fn_gz, '.gz', False)
f_gz_nh.writelines([line1, line2])
fn_gz_upper = self._get_fn('.txt', False) + ".GZ"
- with gzip.GzipFile(filename=fn_gz_upper,
- mode="wb") as f_gz_upper_nh:
+ with GzipFile(filename=fn_gz_upper, mode="wb") as f_gz_upper_nh:
self._set_fn(fn_gz_upper, '.GZ', False)
f_gz_upper_nh.writelines([line1, line2])
fn_bz2 = self._get_fn('.txt', False) + '.bz2'
- with bz2.BZ2File(filename=fn_bz2,
- mode="wb") as f_bz2_nh:
+ with bz2.BZ2File(filename=fn_bz2, mode="wb") as f_bz2_nh:
self._set_fn(fn_bz2, '.bz2', False)
f_bz2_nh.writelines([line1, line2])
# Base Exception so it catches Keyboard Interrupt
@@ -156,15 +149,13 @@ def _check_file_equality(fn_1, fn_2, ext):
# causes filecmp to return False even if contents are identical
# Hence decompress to test for equality
if ext.lower() == '.gz':
- with gzip.GzipFile(fn_1, 'rb') as f_1,\
- NamedTemporaryFile(mode='wb') as f_txt_1,\
- gzip.GzipFile(fn_2, 'rb') as f_2,\
- NamedTemporaryFile(mode='wb') as f_txt_2:
- shutil.copyfileobj(f_1, f_txt_1)
- shutil.copyfileobj(f_2, f_txt_2)
- f_txt_1.flush()
- f_txt_2.flush()
- return filecmp.cmp(f_txt_1.name, f_txt_2.name, shallow=False)
+ with GzipFile(fn_1, 'rb') as f_1, NamedTemporaryFile(mode='wb') as f_txt_1:
+ with GzipFile(fn_2, 'rb') as f_2, NamedTemporaryFile(mode='wb') as f_txt_2:
+ shutil.copyfileobj(f_1, f_txt_1)
+ shutil.copyfileobj(f_2, f_txt_2)
+ f_txt_1.flush()
+ f_txt_2.flush()
+ return filecmp.cmp(f_txt_1.name, f_txt_2.name, shallow=False)
else:
return filecmp.cmp(fn_1, fn_2, shallow=False)
@@ -179,20 +170,20 @@ def test_bad_parameters(self):
def test__get_top_row_as_list(self):
self.kwargs['delimiter'] = '\t'
fn_txt = self._get_fn('.txt', True)
- header_list = S3ToHiveTransfer(**self.kwargs).\
+ header_list = S3ToHiveTransfer(**self.kwargs). \
_get_top_row_as_list(fn_txt)
self.assertEqual(header_list, ['Sno', 'Some,Text'],
msg="Top row from file doesnt matched expected value")
self.kwargs['delimiter'] = ','
- header_list = S3ToHiveTransfer(**self.kwargs).\
+ header_list = S3ToHiveTransfer(**self.kwargs). \
_get_top_row_as_list(fn_txt)
self.assertEqual(header_list, ['Sno\tSome', 'Text'],
msg="Top row from file doesnt matched expected value")
def test__match_headers(self):
self.kwargs['field_dict'] = OrderedDict([('Sno', 'BIGINT'),
- ('Some,Text', 'STRING')])
+ ('Some,Text', 'STRING')])
self.assertTrue(S3ToHiveTransfer(**self.kwargs).
_match_headers(['Sno', 'Some,Text']),
msg="Header row doesnt match expected value")
@@ -250,8 +241,7 @@ def test_execute(self, mock_hiveclihook):
# file parameter to HiveCliHook.load_file is compared
# against expected file output
mock_hiveclihook().load_file.side_effect = \
- lambda *args, **kwargs: \
- self.assertTrue(
+ lambda *args, **kwargs: self.assertTrue(
self._check_file_equality(args[0], op_fn, ext),
msg='{0} output file not as expected'.format(ext))
# Execute S3ToHiveTransfer
diff --git a/tests/operators/subdag_operator.py
b/tests/operators/subdag_operator.py
index af47c5cfd5..bba6f17194 100644
--- a/tests/operators/subdag_operator.py
+++ b/tests/operators/subdag_operator.py
@@ -85,7 +85,7 @@ def test_subdag_pools(self):
session.add(pool_10)
session.commit()
- dummy_1 = DummyOperator(task_id='dummy', dag=subdag, pool='test_pool_1')
+ DummyOperator(task_id='dummy', dag=subdag, pool='test_pool_1')
self.assertRaises(
AirflowException,
@@ -116,8 +116,7 @@ def test_subdag_pools_no_possible_conflict(self):
session.add(pool_10)
session.commit()
- dummy_1 = DummyOperator(
- task_id='dummy', dag=subdag, pool='test_pool_10')
+ DummyOperator(task_id='dummy', dag=subdag, pool='test_pool_10')
mock_session = Mock()
SubDagOperator(
diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py
index f507bdd585..7b8bc4f61c 100644
--- a/tests/plugins/test_plugin.py
+++ b/tests/plugins/test_plugin.py
@@ -30,10 +30,12 @@
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.executors.base_executor import BaseExecutor
+
# Will show up under airflow.hooks.test_plugin.PluginHook
class PluginHook(BaseHook):
pass
+
# Will show up under airflow.operators.test_plugin.PluginOperator
class PluginOperator(BaseOperator):
pass
@@ -48,22 +50,27 @@ class PluginSensorOperator(BaseSensorOperator):
class PluginExecutor(BaseExecutor):
pass
+
# Will show up under airflow.macros.test_plugin.plugin_macro
def plugin_macro():
pass
+
# Creating a flask admin BaseView
class TestView(BaseView):
@expose('/')
def test(self):
- # in this example, put your test_plugin/test.html template at airflow/plugins/templates/test_plugin/test.html
+ # in this example, put your test_plugin/test.html
+ # template at airflow/plugins/templates/test_plugin/test.html
return self.render("test_plugin/test.html", content="Hello galaxy!")
+
+
v = TestView(category="Test Plugin", name="Test View")
# Creating a flask blueprint to intergrate the templates and static folder
bp = Blueprint(
"test_plugin", __name__,
- template_folder='templates', # registers airflow/plugins/templates as a Jinja template folder
+ template_folder='templates',  # registers airflow/plugins/templates as a Jinja template folder
static_folder='static',
static_url_path='/static/test_plugin')
diff --git a/tests/plugins_manager.py b/tests/plugins_manager.py
index 39da5ce448..9f939c37a0 100644
--- a/tests/plugins_manager.py
+++ b/tests/plugins_manager.py
@@ -28,7 +28,7 @@
from flask_admin.menu import MenuLink, MenuView
from airflow.hooks.base_hook import BaseHook
-from airflow.models import BaseOperator
+from airflow.models import BaseOperator
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.executors.base_executor import BaseExecutor
from airflow.www.app import cached_app
diff --git a/tests/sensors/test_sql_sensor.py b/tests/sensors/test_sql_sensor.py
index 5b52f3b654..03ea115356 100644
--- a/tests/sensors/test_sql_sensor.py
+++ b/tests/sensors/test_sql_sensor.py
@@ -40,8 +40,8 @@ def setUp(self):
}
self.dag = DAG(TEST_DAG_ID, default_args=args)
- @unittest.skipUnless('mysql' in configuration.conf.get('core', 'sql_alchemy_conn'),
- "this is a mysql test")
+ @unittest.skipUnless(
+ 'mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), "this is a mysql test")
def test_sql_sensor_mysql(self):
t = SqlSensor(
task_id='sql_sensor_check',
@@ -51,8 +51,8 @@ def test_sql_sensor_mysql(self):
)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
ignore_ti_state=True)
- @unittest.skipUnless('postgresql' in configuration.conf.get('core', 'sql_alchemy_conn'),
- "this is a postgres test")
+ @unittest.skipUnless(
+ 'postgresql' in configuration.conf.get('core', 'sql_alchemy_conn'), "this is a postgres test")
def test_sql_sensor_postgres(self):
t = SqlSensor(
task_id='sql_sensor_check',
@@ -70,9 +70,7 @@ def test_sql_sensor_postgres_poke(self, mock_hook):
sql="SELECT 1",
)
- mock_get_records = (
- mock_hook.get_connection.return_value
- .get_hook.return_value.get_records)
+ mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
mock_get_records.return_value = []
self.assertFalse(t.poke(None))
diff --git a/tests/task/task_runner/__init__.py
b/tests/task/task_runner/__init__.py
index 4067cc78ee..114d189da1 100644
--- a/tests/task/task_runner/__init__.py
+++ b/tests/task/task_runner/__init__.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/test_logging_config.py b/tests/test_logging_config.py
index b568a88cb2..324e33a3ac 100644
--- a/tests/test_logging_config.py
+++ b/tests/test_logging_config.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -148,7 +148,7 @@ def __enter__(self):
return self.settings_file
def __exit__(self, *exc_info):
- #shutil.rmtree(self.settings_root)
+ # shutil.rmtree(self.settings_root)
# Reset config
conf.set('core', 'logging_config_class', '')
sys.path.remove(self.settings_root)
diff --git a/tests/ti_deps/deps/test_runnable_exec_date_dep.py
b/tests/ti_deps/deps/test_runnable_exec_date_dep.py
index 8f59b43ba4..16057c8108 100644
--- a/tests/ti_deps/deps/test_runnable_exec_date_dep.py
+++ b/tests/ti_deps/deps/test_runnable_exec_date_dep.py
@@ -25,6 +25,7 @@
from airflow.ti_deps.deps.runnable_exec_date_dep import RunnableExecDateDep
from airflow.utils.timezone import datetime
+
class RunnableExecDateDepTest(unittest.TestCase):
def _get_task_instance(self, execution_date, dag_end_date=None,
task_end_date=None):
diff --git a/tests/ti_deps/deps/test_task_concurrency.py
b/tests/ti_deps/deps/test_task_concurrency.py
index 940bfca361..ad3c0d9ea0 100644
--- a/tests/ti_deps/deps/test_task_concurrency.py
+++ b/tests/ti_deps/deps/test_task_concurrency.py
@@ -52,4 +52,3 @@ def test_reached_concurrency(self):
self.assertTrue(TaskConcurrencyDep().is_met(ti=ti,
dep_context=dep_context))
ti.get_num_running_task_instances = lambda x: 2
self.assertFalse(TaskConcurrencyDep().is_met(ti=ti,
dep_context=dep_context))
-
diff --git a/tests/utils.py b/tests/utils.py
index 00f126c5da..f670e41183 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -62,6 +62,7 @@ def test_gcs_url_parse(self):
glog.parse_gcs_url('gs://bucket/'),
('bucket', ''))
+
class OperatorResourcesTest(unittest.TestCase):
def setUp(self):
diff --git a/tests/utils/test_dates.py b/tests/utils/test_dates.py
index 84ffb4d7d8..85613d0791 100644
--- a/tests/utils/test_dates.py
+++ b/tests/utils/test_dates.py
@@ -33,26 +33,22 @@ def test_days_ago(self):
self.assertTrue(dates.days_ago(0) == today_midnight)
- self.assertTrue(
- dates.days_ago(100) == today_midnight + timedelta(days=-100))
-
- self.assertTrue(
- dates.days_ago(0, hour=3) == today_midnight + timedelta(hours=3))
- self.assertTrue(
- dates.days_ago(0, minute=3)
- == today_midnight + timedelta(minutes=3))
- self.assertTrue(
- dates.days_ago(0, second=3)
- == today_midnight + timedelta(seconds=3))
- self.assertTrue(
- dates.days_ago(0, microsecond=3)
- == today_midnight + timedelta(microseconds=3))
+ self.assertTrue(dates.days_ago(100) == today_midnight + timedelta(days=-100))
+
+ self.assertTrue(dates.days_ago(0, hour=3) == today_midnight + timedelta(hours=3))
+ self.assertTrue(dates.days_ago(0, minute=3) == today_midnight + timedelta(minutes=3))
+ self.assertTrue(dates.days_ago(0, second=3) == today_midnight + timedelta(seconds=3))
+ self.assertTrue(dates.days_ago(0, microsecond=3) == today_midnight + timedelta(microseconds=3))
def test_parse_execution_date(self):
execution_date_str_wo_ms = '2017-11-02 00:00:00'
execution_date_str_w_ms = '2017-11-05 16:18:30.989729'
bad_execution_date_str = '2017-11-06TXX:00:00Z'
- self.assertEqual(timezone.datetime(2017, 11, 2, 0, 0, 0), dates.parse_execution_date(execution_date_str_wo_ms))
- self.assertEqual(timezone.datetime(2017, 11, 5, 16, 18, 30, 989729), dates.parse_execution_date(execution_date_str_w_ms))
+ self.assertEqual(
+ timezone.datetime(2017, 11, 2, 0, 0, 0),
+ dates.parse_execution_date(execution_date_str_wo_ms))
+ self.assertEqual(
+ timezone.datetime(2017, 11, 5, 16, 18, 30, 989729),
+ dates.parse_execution_date(execution_date_str_w_ms))
self.assertRaises(ValueError, dates.parse_execution_date,
bad_execution_date_str)
diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py
index ffd6df4b04..72a5793f7e 100644
--- a/tests/utils/test_log_handlers.py
+++ b/tests/utils/test_log_handlers.py
@@ -169,14 +169,18 @@ def setUp(self):
self.ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
def test_python_formatting(self):
- expected_filename = 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log' % DEFAULT_DATE.isoformat()
+ expected_filename = \
+ 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log' \
+ % DEFAULT_DATE.isoformat()
fth = FileTaskHandler('',
'{dag_id}/{task_id}/{execution_date}/{try_number}.log')
rendered_filename = fth._render_filename(self.ti, 42)
self.assertEqual(expected_filename, rendered_filename)
def test_jinja_rendering(self):
- expected_filename = 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log' % DEFAULT_DATE.isoformat()
+ expected_filename = \
+ 'dag_for_testing_filename_rendering/task_for_testing_filename_rendering/%s/42.log' \
+ % DEFAULT_DATE.isoformat()
fth = FileTaskHandler('', '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log')
rendered_filename = fth._render_filename(self.ti, 42)
diff --git a/tests/utils/test_logging_mixin.py
b/tests/utils/test_logging_mixin.py
index df7423092f..fa8c589a9e 100644
--- a/tests/utils/test_logging_mixin.py
+++ b/tests/utils/test_logging_mixin.py
@@ -120,4 +120,3 @@ def test_encoding(self):
log = StreamLogWriter(logger, 1)
self.assertFalse(log.encoding)
-
diff --git a/tests/utils/test_timezone.py b/tests/utils/test_timezone.py
index 92fd54ec7a..07e0befcb8 100644
--- a/tests/utils/test_timezone.py
+++ b/tests/utils/test_timezone.py
@@ -24,8 +24,8 @@
from airflow.utils import timezone
CET = pendulum.timezone("Europe/Paris")
-EAT = pendulum.timezone('Africa/Nairobi') # Africa/Nairobi
-ICT = pendulum.timezone('Asia/Bangkok') # Asia/Bangkok
+EAT = pendulum.timezone('Africa/Nairobi') # Africa/Nairobi
+ICT = pendulum.timezone('Asia/Bangkok') # Asia/Bangkok
UTC = timezone.utc
@@ -69,4 +69,3 @@ def test_make_aware(self):
datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT))
with self.assertRaises(ValueError):
timezone.make_aware(datetime.datetime(2011, 9, 1, 13, 20, 30, tzinfo=EAT), EAT)
-
diff --git a/tests/www/__init__.py b/tests/www/__init__.py
index 4924560599..114d189da1 100644
--- a/tests/www/__init__.py
+++ b/tests/www/__init__.py
@@ -16,4 +16,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
diff --git a/tests/www/api/__init__.py b/tests/www/api/__init__.py
index 4924560599..114d189da1 100644
--- a/tests/www/api/__init__.py
+++ b/tests/www/api/__init__.py
@@ -16,4 +16,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
diff --git a/tests/www/api/experimental/__init__.py
b/tests/www/api/experimental/__init__.py
index 4924560599..114d189da1 100644
--- a/tests/www/api/experimental/__init__.py
+++ b/tests/www/api/experimental/__init__.py
@@ -16,4 +16,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
diff --git a/tests/www/api/experimental/test_kerberos_endpoints.py b/tests/www/api/experimental/test_kerberos_endpoints.py
index 63a1557432..b305e19b21 100644
--- a/tests/www/api/experimental/test_kerberos_endpoints.py
+++ b/tests/www/api/experimental/test_kerberos_endpoints.py
@@ -37,14 +37,14 @@ def setUp(self):
configuration.load_test_config()
try:
configuration.conf.add_section("api")
- except:
+ except Exception:
pass
configuration.conf.set("api",
"auth_backend",
"airflow.api.auth.backend.kerberos_auth")
try:
configuration.conf.add_section("kerberos")
- except:
+ except Exception:
pass
configuration.conf.set("kerberos",
"keytab",
diff --git a/tests/www/test_validators.py b/tests/www/test_validators.py
index 0180b734cc..e6242632f2 100644
--- a/tests/www/test_validators.py
+++ b/tests/www/test_validators.py
@@ -92,6 +92,5 @@ def test_validation_raises_custom_message(self):
)
-
if __name__ == '__main__':
unittest.main()
diff --git a/tests/www_rbac/__init__.py b/tests/www_rbac/__init__.py
index 4067cc78ee..114d189da1 100644
--- a/tests/www_rbac/__init__.py
+++ b/tests/www_rbac/__init__.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/www_rbac/api/__init__.py b/tests/www_rbac/api/__init__.py
index 4067cc78ee..114d189da1 100644
--- a/tests/www_rbac/api/__init__.py
+++ b/tests/www_rbac/api/__init__.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/www_rbac/api/experimental/__init__.py b/tests/www_rbac/api/experimental/__init__.py
index 4067cc78ee..114d189da1 100644
--- a/tests/www_rbac/api/experimental/__init__.py
+++ b/tests/www_rbac/api/experimental/__init__.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/www_rbac/api/experimental/test_kerberos_endpoints.py b/tests/www_rbac/api/experimental/test_kerberos_endpoints.py
index 54bbd865b3..788613b504 100644
--- a/tests/www_rbac/api/experimental/test_kerberos_endpoints.py
+++ b/tests/www_rbac/api/experimental/test_kerberos_endpoints.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/www_rbac/test_utils.py b/tests/www_rbac/test_utils.py
index 68d1744ab8..84bf6ce8ae 100644
--- a/tests/www_rbac/test_utils.py
+++ b/tests/www_rbac/test_utils.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/www_rbac/test_validators.py b/tests/www_rbac/test_validators.py
index d7709c4ee7..b50e88de11 100644
--- a/tests/www_rbac/test_validators.py
+++ b/tests/www_rbac/test_validators.py
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
diff --git a/tests/www_rbac/test_views.py b/tests/www_rbac/test_views.py
index a952b9874c..18af9541df 100644
--- a/tests/www_rbac/test_views.py
+++ b/tests/www_rbac/test_views.py
@@ -173,7 +173,7 @@ def test_xss_prevention(self):
self.assertNotIn("<img src='' onerror='alert(1);'>",
resp.data.decode("utf-8"))
- def test_import_variables(self):
+ def test_import_variables_failed(self):
content = '{"str_key": "str_value"}'
with mock.patch('airflow.models.Variable.set') as set_mock:
@@ -192,7 +192,7 @@ def test_import_variables(self):
follow_redirects=True)
self.check_content_in_response('1 variable(s) failed to be updated.', resp)
- def test_import_variables(self):
+ def test_import_variables_success(self):
self.assertEqual(self.session.query(models.Variable).count(), 0)
content = ('{"str_key": "str_value", "int_key": 60,'
diff --git a/tox.ini b/tox.ini
index 17bfeb6ec2..6a55df6e5d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -26,10 +26,6 @@ find_links =
{homedir}/.wheelhouse
{homedir}/.pip-cache
-[flake8]
-max-line-length = 90
-ignore = E731,W503
-
[testenv]
deps =
wheel
@@ -72,5 +68,4 @@ basepython = python3
deps =
flake8
-commands =
- {toxinidir}/scripts/ci/flake8-diff.sh
+commands = flake8
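With the [flake8] block removed from tox.ini (its settings presumably moving to a project-level flake8 config file), the lint environment now runs the bare flake8 command over the whole tree instead of the diff-only shell script. A rough local equivalent, assuming flake8 is installed on the path:

    # Runs the same check the tox environment now runs;
    # flake8 picks up the project-level configuration automatically.
    import subprocess
    subprocess.run(["flake8"], check=True)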
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services