This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch v1-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 9b46a44348d4460cca9059c87ec49ca5b5c629a8 Author: samuelkhtu <[email protected]> AuthorDate: Wed Jun 17 14:32:46 2020 -0400 Merging multiple sql operators (#9124) * Merge various SQL Operators into sql.py * Fix unit test code format * Merge multiple SQL operators into one 1. Merge check_operator.py into airflow.operators.sql 2. Merge sql_branch_operator.py into airflow.operators.sql 3. Merge unit test for both into test_sql.py * Rename test_core_to_contrib Interval/ValueCheckOperator to SQLInterval/ValueCheckOperator * Fixed deprecated class and added check to test_core_to_contrib (cherry picked from commit 0b9bf4a285a074bbde270839a90fb53c257340be) --- ...eea_add_precision_to_execution_date_in_mysql.py | 2 +- airflow/operators/check_operator.py | 425 +++------------------ airflow/operators/{check_operator.py => sql.py} | 419 +++++++++++++++----- airflow/operators/sql_branch_operator.py | 162 +------- docs/operators-and-hooks-ref.rst | 30 +- tests/api/common/experimental/test_pool.py | 64 ++-- tests/contrib/hooks/test_discord_webhook_hook.py | 6 +- .../contrib/operators/test_databricks_operator.py | 6 +- .../contrib/operators/test_gcs_to_gcs_operator.py | 4 +- .../operators/test_qubole_check_operator.py | 7 +- tests/contrib/operators/test_sftp_operator.py | 6 +- tests/contrib/operators/test_ssh_operator.py | 6 +- tests/contrib/operators/test_winrm_operator.py | 6 +- tests/contrib/sensors/test_weekday_sensor.py | 18 +- .../contrib/utils/test_mlengine_operator_utils.py | 16 +- tests/jobs/test_backfill_job.py | 9 +- tests/kubernetes/test_worker_configuration.py | 11 +- tests/models/test_baseoperator.py | 5 +- tests/operators/test_check_operator.py | 327 ---------------- tests/operators/test_s3_to_hive_operator.py | 12 +- .../{test_sql_branch_operator.py => test_sql.py} | 342 +++++++++++++++-- tests/secrets/test_local_filesystem.py | 16 +- tests/sensors/test_http_sensor.py | 3 +- tests/utils/test_compression.py | 16 +- tests/utils/test_decorators.py | 10 +- tests/utils/test_json.py | 11 +- tests/utils/test_module_loading.py | 4 +- tests/www/test_validators.py | 11 +- tests/www_rbac/test_validators.py | 9 +- 29 files changed, 836 insertions(+), 1127 deletions(-) diff --git a/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py b/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py index ecb589d..59098a8 100644 --- a/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py +++ b/airflow/migrations/versions/a66efa278eea_add_precision_to_execution_date_in_mysql.py @@ -29,7 +29,7 @@ from sqlalchemy.dialects import mysql # revision identifiers, used by Alembic. revision = 'a66efa278eea' -down_revision = '8f966b9c467a' +down_revision = '952da73b5eff' branch_labels = None depends_on = None diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py index b6d3a18..12ac472 100644 --- a/airflow/operators/check_operator.py +++ b/airflow/operators/check_operator.py @@ -17,409 +17,70 @@ # specific language governing permissions and limitations # under the License. -from builtins import str, zip -from typing import Optional, Any, Iterable, Dict, SupportsAbs +"""This module is deprecated. Please use `airflow.operators.sql`.""" -from airflow.exceptions import AirflowException -from airflow.hooks.base_hook import BaseHook -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults +import warnings +from airflow.operators.sql import ( + SQLCheckOperator, SQLIntervalCheckOperator, SQLThresholdCheckOperator, SQLValueCheckOperator, +) -class CheckOperator(BaseOperator): - """ - Performs checks against a db. The ``CheckOperator`` expects - a sql query that will return a single row. Each value on that - first row is evaluated using python ``bool`` casting. If any of the - values return ``False`` the check is failed and errors out. - - Note that Python bool casting evals the following as ``False``: - - * ``False`` - * ``0`` - * Empty string (``""``) - * Empty list (``[]``) - * Empty dictionary or set (``{}``) - - Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if - the count ``== 0``. You can craft much more complex query that could, - for instance, check that the table has the same number of rows as - the source table upstream, or that the count of today's partition is - greater than yesterday's partition, or that a set of metrics are less - than 3 standard deviation for the 7 day average. - - This operator can be used as a data quality check in your pipeline, and - depending on where you put it in your DAG, you have the choice to - stop the critical path, preventing from - publishing dubious data, or on the side and receive email alerts - without stopping the progress of the DAG. - Note that this is an abstract class and get_db_hook - needs to be defined. Whereas a get_db_hook is hook that gets a - single record from an external source. - - :param sql: the sql to be executed. (templated) - :type sql: str +class CheckOperator(SQLCheckOperator): """ - - template_fields = ('sql',) # type: Iterable[str] - template_ext = ('.hql', '.sql',) # type: Iterable[str] - ui_color = '#fff7e6' - - @apply_defaults - def __init__( - self, - sql, # type: str - conn_id=None, # type: Optional[str] - *args, - **kwargs - ): - super(CheckOperator, self).__init__(*args, **kwargs) - self.conn_id = conn_id - self.sql = sql - - def execute(self, context=None): - self.log.info('Executing SQL check: %s', self.sql) - records = self.get_db_hook().get_first(self.sql) - - self.log.info('Record: %s', records) - if not records: - raise AirflowException("The query returned None") - elif not all([bool(r) for r in records]): - raise AirflowException("Test failed.\nQuery:\n{query}\nResults:\n{records!s}".format( - query=self.sql, records=records)) - - self.log.info("Success.") - - def get_db_hook(self): - return BaseHook.get_hook(conn_id=self.conn_id) - - -def _convert_to_float_if_possible(s): + This class is deprecated. + Please use `airflow.operators.sql.SQLCheckOperator`. """ - A small helper function to convert a string to a numeric value - if appropriate - :param s: the string to be converted - :type s: str - """ - try: - ret = float(s) - except (ValueError, TypeError): - ret = s - return ret + def __init__(self, *args, **kwargs): + warnings.warn( + """This class is deprecated. + Please use `airflow.operators.sql.SQLCheckOperator`.""", + DeprecationWarning, stacklevel=2 + ) + super(CheckOperator, self).__init__(*args, **kwargs) -class ValueCheckOperator(BaseOperator): +class IntervalCheckOperator(SQLIntervalCheckOperator): """ - Performs a simple value check using sql code. - - Note that this is an abstract class and get_db_hook - needs to be defined. Whereas a get_db_hook is hook that gets a - single record from an external source. - - :param sql: the sql to be executed. (templated) - :type sql: str + This class is deprecated. + Please use `airflow.operators.sql.SQLIntervalCheckOperator`. """ - __mapper_args__ = { - 'polymorphic_identity': 'ValueCheckOperator' - } - template_fields = ('sql', 'pass_value',) # type: Iterable[str] - template_ext = ('.hql', '.sql',) # type: Iterable[str] - ui_color = '#fff7e6' - - @apply_defaults - def __init__( - self, - sql, # type: str - pass_value, # type: Any - tolerance=None, # type: Any - conn_id=None, # type: Optional[str] - *args, - **kwargs - ): - super(ValueCheckOperator, self).__init__(*args, **kwargs) - self.sql = sql - self.conn_id = conn_id - self.pass_value = str(pass_value) - tol = _convert_to_float_if_possible(tolerance) - self.tol = tol if isinstance(tol, float) else None - self.has_tolerance = self.tol is not None - - def execute(self, context=None): - self.log.info('Executing SQL check: %s', self.sql) - records = self.get_db_hook().get_first(self.sql) - - if not records: - raise AirflowException("The query returned None") - - pass_value_conv = _convert_to_float_if_possible(self.pass_value) - is_numeric_value_check = isinstance(pass_value_conv, float) - - tolerance_pct_str = str(self.tol * 100) + '%' if self.has_tolerance else None - error_msg = ("Test failed.\nPass value:{pass_value_conv}\n" - "Tolerance:{tolerance_pct_str}\n" - "Query:\n{sql}\nResults:\n{records!s}").format( - pass_value_conv=pass_value_conv, - tolerance_pct_str=tolerance_pct_str, - sql=self.sql, - records=records + def __init__(self, *args, **kwargs): + warnings.warn( + """This class is deprecated. + Please use `airflow.operators.sql.SQLIntervalCheckOperator`.""", + DeprecationWarning, stacklevel=2 ) - - if not is_numeric_value_check: - tests = self._get_string_matches(records, pass_value_conv) - elif is_numeric_value_check: - try: - numeric_records = self._to_float(records) - except (ValueError, TypeError): - raise AirflowException("Converting a result to float failed.\n{}".format(error_msg)) - tests = self._get_numeric_matches(numeric_records, pass_value_conv) - else: - tests = [] - - if not all(tests): - raise AirflowException(error_msg) - - def _to_float(self, records): - return [float(record) for record in records] - - def _get_string_matches(self, records, pass_value_conv): - return [str(record) == pass_value_conv for record in records] - - def _get_numeric_matches(self, numeric_records, numeric_pass_value_conv): - if self.has_tolerance: - return [ - numeric_pass_value_conv * (1 - self.tol) <= record <= numeric_pass_value_conv * (1 + self.tol) - for record in numeric_records - ] - - return [record == numeric_pass_value_conv for record in numeric_records] - - def get_db_hook(self): - return BaseHook.get_hook(conn_id=self.conn_id) + super(IntervalCheckOperator, self).__init__(*args, **kwargs) -class IntervalCheckOperator(BaseOperator): +class ThresholdCheckOperator(SQLThresholdCheckOperator): """ - Checks that the values of metrics given as SQL expressions are within - a certain tolerance of the ones from days_back before. - - Note that this is an abstract class and get_db_hook - needs to be defined. Whereas a get_db_hook is hook that gets a - single record from an external source. - - :param table: the table name - :type table: str - :param days_back: number of days between ds and the ds we want to check - against. Defaults to 7 days - :type days_back: int - :param ratio_formula: which formula to use to compute the ratio between - the two metrics. Assuming cur is the metric of today and ref is - the metric to today - days_back. - - max_over_min: computes max(cur, ref) / min(cur, ref) - relative_diff: computes abs(cur-ref) / ref - - Default: 'max_over_min' - :type ratio_formula: str - :param ignore_zero: whether we should ignore zero metrics - :type ignore_zero: bool - :param metrics_threshold: a dictionary of ratios indexed by metrics - :type metrics_threshold: dict + This class is deprecated. + Please use `airflow.operators.sql.SQLThresholdCheckOperator`. """ - __mapper_args__ = { - 'polymorphic_identity': 'IntervalCheckOperator' - } - template_fields = ('sql1', 'sql2') # type: Iterable[str] - template_ext = ('.hql', '.sql',) # type: Iterable[str] - ui_color = '#fff7e6' - - ratio_formulas = { - 'max_over_min': lambda cur, ref: float(max(cur, ref)) / min(cur, ref), - 'relative_diff': lambda cur, ref: float(abs(cur - ref)) / ref, - } - - @apply_defaults - def __init__( - self, - table, # type: str - metrics_thresholds, # type: Dict[str, int] - date_filter_column='ds', # type: Optional[str] - days_back=-7, # type: SupportsAbs[int] - ratio_formula='max_over_min', # type: Optional[str] - ignore_zero=True, # type: Optional[bool] - conn_id=None, # type: Optional[str] - *args, **kwargs - ): - super(IntervalCheckOperator, self).__init__(*args, **kwargs) - if ratio_formula not in self.ratio_formulas: - msg_template = "Invalid diff_method: {diff_method}. " \ - "Supported diff methods are: {diff_methods}" - - raise AirflowException( - msg_template.format(diff_method=ratio_formula, - diff_methods=self.ratio_formulas) - ) - self.ratio_formula = ratio_formula - self.ignore_zero = ignore_zero - self.table = table - self.metrics_thresholds = metrics_thresholds - self.metrics_sorted = sorted(metrics_thresholds.keys()) - self.date_filter_column = date_filter_column - self.days_back = -abs(days_back) - self.conn_id = conn_id - sqlexp = ', '.join(self.metrics_sorted) - sqlt = "SELECT {sqlexp} FROM {table} WHERE {date_filter_column}=".format( - sqlexp=sqlexp, table=table, date_filter_column=date_filter_column + def __init__(self, *args, **kwargs): + warnings.warn( + """This class is deprecated. + Please use `airflow.operators.sql.SQLThresholdCheckOperator`.""", + DeprecationWarning, stacklevel=2 ) - - self.sql1 = sqlt + "'{{ ds }}'" - self.sql2 = sqlt + "'{{ macros.ds_add(ds, " + str(self.days_back) + ") }}'" - - def execute(self, context=None): - hook = self.get_db_hook() - self.log.info('Using ratio formula: %s', self.ratio_formula) - self.log.info('Executing SQL check: %s', self.sql2) - row2 = hook.get_first(self.sql2) - self.log.info('Executing SQL check: %s', self.sql1) - row1 = hook.get_first(self.sql1) - - if not row2: - raise AirflowException("The query {} returned None".format(self.sql2)) - if not row1: - raise AirflowException("The query {} returned None".format(self.sql1)) - - current = dict(zip(self.metrics_sorted, row1)) - reference = dict(zip(self.metrics_sorted, row2)) - - ratios = {} - test_results = {} - - for m in self.metrics_sorted: - cur = current[m] - ref = reference[m] - threshold = self.metrics_thresholds[m] - if cur == 0 or ref == 0: - ratios[m] = None - test_results[m] = self.ignore_zero - else: - ratios[m] = self.ratio_formulas[self.ratio_formula](current[m], reference[m]) - test_results[m] = ratios[m] < threshold - - self.log.info( - ( - "Current metric for %s: %s\n" - "Past metric for %s: %s\n" - "Ratio for %s: %s\n" - "Threshold: %s\n" - ), m, cur, m, ref, m, ratios[m], threshold) - - if not all(test_results.values()): - failed_tests = [it[0] for it in test_results.items() if not it[1]] - j = len(failed_tests) - n = len(self.metrics_sorted) - self.log.warning("The following %s tests out of %s failed:", j, n) - for k in failed_tests: - self.log.warning( - "'%s' check failed. %s is above %s", k, ratios[k], self.metrics_thresholds[k] - ) - raise AirflowException("The following tests have failed:\n {0}".format(", ".join( - sorted(failed_tests)))) - - self.log.info("All tests have passed") - - def get_db_hook(self): - return BaseHook.get_hook(conn_id=self.conn_id) + super(ThresholdCheckOperator, self).__init__(*args, **kwargs) -class ThresholdCheckOperator(BaseOperator): +class ValueCheckOperator(SQLValueCheckOperator): """ - Performs a value check using sql code against a mininmum threshold - and a maximum threshold. Thresholds can be in the form of a numeric - value OR a sql statement that results a numeric. - - Note that this is an abstract class and get_db_hook - needs to be defined. Whereas a get_db_hook is hook that gets a - single record from an external source. - - :param sql: the sql to be executed. (templated) - :type sql: str - :param min_threshold: numerical value or min threshold sql to be executed (templated) - :type min_threshold: numeric or str - :param max_threshold: numerical value or max threshold sql to be executed (templated) - :type max_threshold: numeric or str + This class is deprecated. + Please use `airflow.operators.sql.SQLValueCheckOperator`. """ - template_fields = ('sql', 'min_threshold', 'max_threshold') # type: Iterable[str] - template_ext = ('.hql', '.sql',) # type: Iterable[str] - - @apply_defaults - def __init__( - self, - sql, # type: str - min_threshold, # type: Any - max_threshold, # type: Any - conn_id=None, # type: Optional[str] - *args, **kwargs - ): - super(ThresholdCheckOperator, self).__init__(*args, **kwargs) - self.sql = sql - self.conn_id = conn_id - self.min_threshold = _convert_to_float_if_possible(min_threshold) - self.max_threshold = _convert_to_float_if_possible(max_threshold) - - def execute(self, context=None): - hook = self.get_db_hook() - result = hook.get_first(self.sql)[0][0] - - if isinstance(self.min_threshold, float): - lower_bound = self.min_threshold - else: - lower_bound = hook.get_first(self.min_threshold)[0][0] - - if isinstance(self.max_threshold, float): - upper_bound = self.max_threshold - else: - upper_bound = hook.get_first(self.max_threshold)[0][0] - - meta_data = { - "result": result, - "task_id": self.task_id, - "min_threshold": lower_bound, - "max_threshold": upper_bound, - "within_threshold": lower_bound <= result <= upper_bound - } - - self.push(meta_data) - if not meta_data["within_threshold"]: - error_msg = ( - 'Threshold Check: "{task_id}" failed.\n' - 'DAG: {dag_id}\nTask_id: {task_id}\n' - 'Check description: {description}\n' - 'SQL: {sql}\n' - 'Result: {result} is not within thresholds ' - '{min_threshold} and {max_threshold}' - ).format( - task_id=self.task_id, dag_id=self.dag_id, - description=meta_data.get("description"), sql=self.sql, - result=round(meta_data.get("result"), 2), - min_threshold=meta_data.get("min_threshold"), - max_threshold=meta_data.get("max_threshold") - ) - raise AirflowException(error_msg) - - self.log.info("Test %s Successful.", self.task_id) - - def push(self, meta_data): - """ - Optional: Send data check info and metadata to an external database. - Default functionality will log metadata. - """ - - info = "\n".join(["""{}: {}""".format(key, item) for key, item in meta_data.items()]) - self.log.info("Log from %s:\n%s", self.dag_id, info) - - def get_db_hook(self): - return BaseHook.get_hook(conn_id=self.conn_id) + def __init__(self, *args, **kwargs): + warnings.warn( + """This class is deprecated. + Please use `airflow.operators.sql.SQLValueCheckOperator`.""", + DeprecationWarning, stacklevel=2 + ) + super(ValueCheckOperator, self).__init__(*args, **kwargs) diff --git a/airflow/operators/check_operator.py b/airflow/operators/sql.py similarity index 50% copy from airflow/operators/check_operator.py copy to airflow/operators/sql.py index b6d3a18..3e53fbf 100644 --- a/airflow/operators/check_operator.py +++ b/airflow/operators/sql.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -16,19 +15,31 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -from builtins import str, zip -from typing import Optional, Any, Iterable, Dict, SupportsAbs +from distutils.util import strtobool +from typing import Iterable from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook -from airflow.models import BaseOperator +from airflow.models import BaseOperator, SkipMixin from airflow.utils.decorators import apply_defaults - -class CheckOperator(BaseOperator): +ALLOWED_CONN_TYPE = { + "google_cloud_platform", + "jdbc", + "mssql", + "mysql", + "odbc", + "oracle", + "postgres", + "presto", + "sqlite", + "vertica", +} + + +class SQLCheckOperator(BaseOperator): """ - Performs checks against a db. The ``CheckOperator`` expects + Performs checks against a db. The ``SQLCheckOperator`` expects a sql query that will return a single row. Each value on that first row is evaluated using python ``bool`` casting. If any of the values return ``False`` the check is failed and errors out. @@ -62,36 +73,44 @@ class CheckOperator(BaseOperator): :type sql: str """ - template_fields = ('sql',) # type: Iterable[str] - template_ext = ('.hql', '.sql',) # type: Iterable[str] - ui_color = '#fff7e6' + template_fields = ("sql",) # type: Iterable[str] + template_ext = ( + ".hql", + ".sql", + ) # type: Iterable[str] + ui_color = "#fff7e6" @apply_defaults def __init__( - self, - sql, # type: str - conn_id=None, # type: Optional[str] - *args, - **kwargs + self, sql, conn_id=None, *args, **kwargs ): - super(CheckOperator, self).__init__(*args, **kwargs) + super(SQLCheckOperator, self).__init__(*args, **kwargs) self.conn_id = conn_id self.sql = sql def execute(self, context=None): - self.log.info('Executing SQL check: %s', self.sql) + self.log.info("Executing SQL check: %s", self.sql) records = self.get_db_hook().get_first(self.sql) - self.log.info('Record: %s', records) + self.log.info("Record: %s", records) if not records: raise AirflowException("The query returned None") elif not all([bool(r) for r in records]): - raise AirflowException("Test failed.\nQuery:\n{query}\nResults:\n{records!s}".format( - query=self.sql, records=records)) + raise AirflowException( + "Test failed.\nQuery:\n{query}\nResults:\n{records!s}".format( + query=self.sql, records=records + ) + ) self.log.info("Success.") def get_db_hook(self): + """ + Get the database hook for the connection. + + :return: the database hook object. + :rtype: DbApiHook + """ return BaseHook.get_hook(conn_id=self.conn_id) @@ -110,7 +129,7 @@ def _convert_to_float_if_possible(s): return ret -class ValueCheckOperator(BaseOperator): +class SQLValueCheckOperator(BaseOperator): """ Performs a simple value check using sql code. @@ -122,24 +141,27 @@ class ValueCheckOperator(BaseOperator): :type sql: str """ - __mapper_args__ = { - 'polymorphic_identity': 'ValueCheckOperator' - } - template_fields = ('sql', 'pass_value',) # type: Iterable[str] - template_ext = ('.hql', '.sql',) # type: Iterable[str] - ui_color = '#fff7e6' + __mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"} + template_fields = ( + "sql", + "pass_value", + ) # type: Iterable[str] + template_ext = ( + ".hql", + ".sql", + ) # type: Iterable[str] + ui_color = "#fff7e6" @apply_defaults def __init__( - self, - sql, # type: str - pass_value, # type: Any - tolerance=None, # type: Any - conn_id=None, # type: Optional[str] - *args, - **kwargs - ): - super(ValueCheckOperator, self).__init__(*args, **kwargs) + self, + sql, + pass_value, + tolerance=None, + conn_id=None, + *args, + **kwargs): + super(SQLValueCheckOperator, self).__init__(*args, **kwargs) self.sql = sql self.conn_id = conn_id self.pass_value = str(pass_value) @@ -148,7 +170,7 @@ class ValueCheckOperator(BaseOperator): self.has_tolerance = self.tol is not None def execute(self, context=None): - self.log.info('Executing SQL check: %s', self.sql) + self.log.info("Executing SQL check: %s", self.sql) records = self.get_db_hook().get_first(self.sql) if not records: @@ -157,14 +179,16 @@ class ValueCheckOperator(BaseOperator): pass_value_conv = _convert_to_float_if_possible(self.pass_value) is_numeric_value_check = isinstance(pass_value_conv, float) - tolerance_pct_str = str(self.tol * 100) + '%' if self.has_tolerance else None - error_msg = ("Test failed.\nPass value:{pass_value_conv}\n" - "Tolerance:{tolerance_pct_str}\n" - "Query:\n{sql}\nResults:\n{records!s}").format( + tolerance_pct_str = str(self.tol * 100) + "%" if self.has_tolerance else None + error_msg = ( + "Test failed.\nPass value:{pass_value_conv}\n" + "Tolerance:{tolerance_pct_str}\n" + "Query:\n{sql}\nResults:\n{records!s}" + ).format( pass_value_conv=pass_value_conv, tolerance_pct_str=tolerance_pct_str, sql=self.sql, - records=records + records=records, ) if not is_numeric_value_check: @@ -173,7 +197,9 @@ class ValueCheckOperator(BaseOperator): try: numeric_records = self._to_float(records) except (ValueError, TypeError): - raise AirflowException("Converting a result to float failed.\n{}".format(error_msg)) + raise AirflowException( + "Converting a result to float failed.\n{}".format(error_msg) + ) tests = self._get_numeric_matches(numeric_records, pass_value_conv) else: tests = [] @@ -197,10 +223,16 @@ class ValueCheckOperator(BaseOperator): return [record == numeric_pass_value_conv for record in numeric_records] def get_db_hook(self): + """ + Get the database hook for the connection. + + :return: the database hook object. + :rtype: DbApiHook + """ return BaseHook.get_hook(conn_id=self.conn_id) -class IntervalCheckOperator(BaseOperator): +class SQLIntervalCheckOperator(BaseOperator): """ Checks that the values of metrics given as SQL expressions are within a certain tolerance of the ones from days_back before. @@ -229,38 +261,43 @@ class IntervalCheckOperator(BaseOperator): :type metrics_threshold: dict """ - __mapper_args__ = { - 'polymorphic_identity': 'IntervalCheckOperator' - } - template_fields = ('sql1', 'sql2') # type: Iterable[str] - template_ext = ('.hql', '.sql',) # type: Iterable[str] - ui_color = '#fff7e6' + __mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"} + template_fields = ("sql1", "sql2") # type: Iterable[str] + template_ext = ( + ".hql", + ".sql", + ) # type: Iterable[str] + ui_color = "#fff7e6" ratio_formulas = { - 'max_over_min': lambda cur, ref: float(max(cur, ref)) / min(cur, ref), - 'relative_diff': lambda cur, ref: float(abs(cur - ref)) / ref, + "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref), + "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref, } @apply_defaults def __init__( self, - table, # type: str - metrics_thresholds, # type: Dict[str, int] - date_filter_column='ds', # type: Optional[str] - days_back=-7, # type: SupportsAbs[int] - ratio_formula='max_over_min', # type: Optional[str] - ignore_zero=True, # type: Optional[bool] - conn_id=None, # type: Optional[str] - *args, **kwargs + table, + metrics_thresholds, + date_filter_column="ds", + days_back=-7, + ratio_formula="max_over_min", + ignore_zero=True, + conn_id=None, + *args, + **kwargs ): - super(IntervalCheckOperator, self).__init__(*args, **kwargs) + super(SQLIntervalCheckOperator, self).__init__(*args, **kwargs) if ratio_formula not in self.ratio_formulas: - msg_template = "Invalid diff_method: {diff_method}. " \ - "Supported diff methods are: {diff_methods}" + msg_template = ( + "Invalid diff_method: {diff_method}. " + "Supported diff methods are: {diff_methods}" + ) raise AirflowException( - msg_template.format(diff_method=ratio_formula, - diff_methods=self.ratio_formulas) + msg_template.format( + diff_method=ratio_formula, diff_methods=self.ratio_formulas + ) ) self.ratio_formula = ratio_formula self.ignore_zero = ignore_zero @@ -270,7 +307,7 @@ class IntervalCheckOperator(BaseOperator): self.date_filter_column = date_filter_column self.days_back = -abs(days_back) self.conn_id = conn_id - sqlexp = ', '.join(self.metrics_sorted) + sqlexp = ", ".join(self.metrics_sorted) sqlt = "SELECT {sqlexp} FROM {table} WHERE {date_filter_column}=".format( sqlexp=sqlexp, table=table, date_filter_column=date_filter_column ) @@ -280,10 +317,10 @@ class IntervalCheckOperator(BaseOperator): def execute(self, context=None): hook = self.get_db_hook() - self.log.info('Using ratio formula: %s', self.ratio_formula) - self.log.info('Executing SQL check: %s', self.sql2) + self.log.info("Using ratio formula: %s", self.ratio_formula) + self.log.info("Executing SQL check: %s", self.sql2) row2 = hook.get_first(self.sql2) - self.log.info('Executing SQL check: %s', self.sql1) + self.log.info("Executing SQL check: %s", self.sql1) row1 = hook.get_first(self.sql1) if not row2: @@ -297,16 +334,18 @@ class IntervalCheckOperator(BaseOperator): ratios = {} test_results = {} - for m in self.metrics_sorted: - cur = current[m] - ref = reference[m] - threshold = self.metrics_thresholds[m] + for metric in self.metrics_sorted: + cur = current[metric] + ref = reference[metric] + threshold = self.metrics_thresholds[metric] if cur == 0 or ref == 0: - ratios[m] = None - test_results[m] = self.ignore_zero + ratios[metric] = None + test_results[metric] = self.ignore_zero else: - ratios[m] = self.ratio_formulas[self.ratio_formula](current[m], reference[m]) - test_results[m] = ratios[m] < threshold + ratios[metric] = self.ratio_formulas[self.ratio_formula]( + current[metric], reference[metric] + ) + test_results[metric] = ratios[metric] < threshold self.log.info( ( @@ -314,27 +353,49 @@ class IntervalCheckOperator(BaseOperator): "Past metric for %s: %s\n" "Ratio for %s: %s\n" "Threshold: %s\n" - ), m, cur, m, ref, m, ratios[m], threshold) + ), + metric, + cur, + metric, + ref, + metric, + ratios[metric], + threshold, + ) if not all(test_results.values()): failed_tests = [it[0] for it in test_results.items() if not it[1]] - j = len(failed_tests) - n = len(self.metrics_sorted) - self.log.warning("The following %s tests out of %s failed:", j, n) + self.log.warning( + "The following %s tests out of %s failed:", + len(failed_tests), + len(self.metrics_sorted), + ) for k in failed_tests: self.log.warning( - "'%s' check failed. %s is above %s", k, ratios[k], self.metrics_thresholds[k] + "'%s' check failed. %s is above %s", + k, + ratios[k], + self.metrics_thresholds[k], + ) + raise AirflowException( + "The following tests have failed:\n {0}".format( + ", ".join(sorted(failed_tests)) ) - raise AirflowException("The following tests have failed:\n {0}".format(", ".join( - sorted(failed_tests)))) + ) self.log.info("All tests have passed") def get_db_hook(self): + """ + Get the database hook for the connection. + + :return: the database hook object. + :rtype: DbApiHook + """ return BaseHook.get_hook(conn_id=self.conn_id) -class ThresholdCheckOperator(BaseOperator): +class SQLThresholdCheckOperator(BaseOperator): """ Performs a value check using sql code against a mininmum threshold and a maximum threshold. Thresholds can be in the form of a numeric @@ -352,19 +413,23 @@ class ThresholdCheckOperator(BaseOperator): :type max_threshold: numeric or str """ - template_fields = ('sql', 'min_threshold', 'max_threshold') # type: Iterable[str] - template_ext = ('.hql', '.sql',) # type: Iterable[str] + template_fields = ("sql", "min_threshold", "max_threshold") # type: Iterable[str] + template_ext = ( + ".hql", + ".sql", + ) # type: Iterable[str] @apply_defaults def __init__( self, - sql, # type: str - min_threshold, # type: Any - max_threshold, # type: Any - conn_id=None, # type: Optional[str] - *args, **kwargs + sql, + min_threshold, + max_threshold, + conn_id=None, + *args, + **kwargs ): - super(ThresholdCheckOperator, self).__init__(*args, **kwargs) + super(SQLThresholdCheckOperator, self).__init__(*args, **kwargs) self.sql = sql self.conn_id = conn_id self.min_threshold = _convert_to_float_if_possible(min_threshold) @@ -389,7 +454,7 @@ class ThresholdCheckOperator(BaseOperator): "task_id": self.task_id, "min_threshold": lower_bound, "max_threshold": upper_bound, - "within_threshold": lower_bound <= result <= upper_bound + "within_threshold": lower_bound <= result <= upper_bound, } self.push(meta_data) @@ -398,16 +463,17 @@ class ThresholdCheckOperator(BaseOperator): 'Threshold Check: "{task_id}" failed.\n' 'DAG: {dag_id}\nTask_id: {task_id}\n' 'Check description: {description}\n' - 'SQL: {sql}\n' - 'Result: {result} is not within thresholds ' - '{min_threshold} and {max_threshold}' - ).format( - task_id=self.task_id, dag_id=self.dag_id, - description=meta_data.get("description"), sql=self.sql, - result=round(meta_data.get("result"), 2), - min_threshold=meta_data.get("min_threshold"), - max_threshold=meta_data.get("max_threshold") - ) + "SQL: {sql}\n" + 'Result: {round} is not within thresholds ' + '{min} and {max}' + .format(task_id=meta_data.get("task_id"), + dag_id=self.dag_id, + description=meta_data.get("description"), + sql=self.sql, + round=round(meta_data.get("result"), 2), + min=meta_data.get("min_threshold"), + max=meta_data.get("max_threshold"), + )) raise AirflowException(error_msg) self.log.info("Test %s Successful.", self.task_id) @@ -418,8 +484,149 @@ class ThresholdCheckOperator(BaseOperator): Default functionality will log metadata. """ - info = "\n".join(["""{}: {}""".format(key, item) for key, item in meta_data.items()]) + info = "\n".join(["{key}: {item}".format(key=key, item=item) for key, item in meta_data.items()]) self.log.info("Log from %s:\n%s", self.dag_id, info) def get_db_hook(self): + """ + Returns DB hook + """ return BaseHook.get_hook(conn_id=self.conn_id) + + +class BranchSQLOperator(BaseOperator, SkipMixin): + """ + Executes sql code in a specific database + + :param sql: the sql code to be executed. (templated) + :type sql: Can receive a str representing a sql statement or reference to a template file. + Template reference are recognized by str ending in '.sql'. + Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1) + or string (true/y/yes/1/on/false/n/no/0/off). + :param follow_task_ids_if_true: task id or task ids to follow if query return true + :type follow_task_ids_if_true: str or list + :param follow_task_ids_if_false: task id or task ids to follow if query return true + :type follow_task_ids_if_false: str or list + :param conn_id: reference to a specific database + :type conn_id: str + :param database: name of database which overwrite defined one in connection + :param parameters: (optional) the parameters to render the SQL query with. + :type parameters: mapping or iterable + """ + + template_fields = ("sql",) + template_ext = (".sql",) + ui_color = "#a22034" + ui_fgcolor = "#F7F7F7" + + @apply_defaults + def __init__( + self, + sql, + follow_task_ids_if_true, + follow_task_ids_if_false, + conn_id="default_conn_id", + database=None, + parameters=None, + *args, + **kwargs + ): + super(BranchSQLOperator, self).__init__(*args, **kwargs) + self.conn_id = conn_id + self.sql = sql + self.parameters = parameters + self.follow_task_ids_if_true = follow_task_ids_if_true + self.follow_task_ids_if_false = follow_task_ids_if_false + self.database = database + self._hook = None + + def _get_hook(self): + self.log.debug("Get connection for %s", self.conn_id) + conn = BaseHook.get_connection(self.conn_id) + + if conn.conn_type not in ALLOWED_CONN_TYPE: + raise AirflowException( + "The connection type is not supported by BranchSQLOperator.\ + Supported connection types: {}".format(list(ALLOWED_CONN_TYPE)) + ) + + if not self._hook: + self._hook = conn.get_hook() + if self.database: + self._hook.schema = self.database + + return self._hook + + def execute(self, context): + # get supported hook + self._hook = self._get_hook() + + if self._hook is None: + raise AirflowException( + "Failed to establish connection to '%s'" % self.conn_id + ) + + if self.sql is None: + raise AirflowException("Expected 'sql' parameter is missing.") + + if self.follow_task_ids_if_true is None: + raise AirflowException( + "Expected 'follow_task_ids_if_true' paramter is missing." + ) + + if self.follow_task_ids_if_false is None: + raise AirflowException( + "Expected 'follow_task_ids_if_false' parameter is missing." + ) + + self.log.info( + "Executing: %s (with parameters %s) with connection: %s", + self.sql, + self.parameters, + self._hook, + ) + record = self._hook.get_first(self.sql, self.parameters) + if not record: + raise AirflowException( + "No rows returned from sql query. Operator expected True or False return value." + ) + + if isinstance(record, list): + if isinstance(record[0], list): + query_result = record[0][0] + else: + query_result = record[0] + elif isinstance(record, tuple): + query_result = record[0] + else: + query_result = record + + self.log.info("Query returns %s, type '%s'", query_result, type(query_result)) + + follow_branch = None + try: + if isinstance(query_result, bool): + if query_result: + follow_branch = self.follow_task_ids_if_true + elif isinstance(query_result, str): + # return result is not Boolean, try to convert from String to Boolean + if bool(strtobool(query_result)): + follow_branch = self.follow_task_ids_if_true + elif isinstance(query_result, int): + if bool(query_result): + follow_branch = self.follow_task_ids_if_true + else: + raise AirflowException( + "Unexpected query return result '%s' type '%s'" + % (query_result, type(query_result)) + ) + + if follow_branch is None: + follow_branch = self.follow_task_ids_if_false + except ValueError: + raise AirflowException( + "Unexpected query return result '%s' type '%s'" + % (query_result, type(query_result)) + ) + + self.skip_all_except(context["ti"], follow_branch) diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py index 072c40c..b911e34 100644 --- a/airflow/operators/sql_branch_operator.py +++ b/airflow/operators/sql_branch_operator.py @@ -14,160 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""This module is deprecated. Please use `airflow.operators.sql`.""" +import warnings -from distutils.util import strtobool +from airflow.operators.sql import BranchSQLOperator -from airflow.exceptions import AirflowException -from airflow.hooks.base_hook import BaseHook -from airflow.models import BaseOperator, SkipMixin -from airflow.utils.decorators import apply_defaults -ALLOWED_CONN_TYPE = { - "google_cloud_platform", - "jdbc", - "mssql", - "mysql", - "odbc", - "oracle", - "postgres", - "presto", - "sqlite", - "vertica", -} - - -class BranchSqlOperator(BaseOperator, SkipMixin): +class BranchSqlOperator(BranchSQLOperator): """ - Executes sql code in a specific database - - :param sql: the sql code to be executed. (templated) - :type sql: Can receive a str representing a sql statement or reference to a template file. - Template reference are recognized by str ending in '.sql'. - Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1) - or string (true/y/yes/1/on/false/n/no/0/off). - :param follow_task_ids_if_true: task id or task ids to follow if query return true - :type follow_task_ids_if_true: str or list - :param follow_task_ids_if_false: task id or task ids to follow if query return true - :type follow_task_ids_if_false: str or list - :param conn_id: reference to a specific database - :type conn_id: str - :param database: name of database which overwrite defined one in connection - :param parameters: (optional) the parameters to render the SQL query with. - :type parameters: mapping or iterable + This class is deprecated. + Please use `airflow.operators.sql.BranchSQLOperator`. """ - template_fields = ("sql",) - template_ext = (".sql",) - ui_color = "#a22034" - ui_fgcolor = "#F7F7F7" - - @apply_defaults - def __init__( - self, - sql, - follow_task_ids_if_true, - follow_task_ids_if_false, - conn_id="default_conn_id", - database=None, - parameters=None, - *args, - **kwargs): - super(BranchSqlOperator, self).__init__(*args, **kwargs) - self.conn_id = conn_id - self.sql = sql - self.parameters = parameters - self.follow_task_ids_if_true = follow_task_ids_if_true - self.follow_task_ids_if_false = follow_task_ids_if_false - self.database = database - self._hook = None - - def _get_hook(self): - self.log.debug("Get connection for %s", self.conn_id) - conn = BaseHook.get_connection(self.conn_id) - - if conn.conn_type not in ALLOWED_CONN_TYPE: - raise AirflowException( - "The connection type is not supported by BranchSqlOperator. " - + "Supported connection types: {}".format(list(ALLOWED_CONN_TYPE)) - ) - - if not self._hook: - self._hook = conn.get_hook() - if self.database: - self._hook.schema = self.database - - return self._hook - - def execute(self, context): - # get supported hook - self._hook = self._get_hook() - - if self._hook is None: - raise AirflowException( - "Failed to establish connection to '%s'" % self.conn_id - ) - - if self.sql is None: - raise AirflowException("Expected 'sql' parameter is missing.") - - if self.follow_task_ids_if_true is None: - raise AirflowException( - "Expected 'follow_task_ids_if_true' paramter is missing." - ) - - if self.follow_task_ids_if_false is None: - raise AirflowException( - "Expected 'follow_task_ids_if_false' parameter is missing." - ) - - self.log.info( - "Executing: %s (with parameters %s) with connection: %s", - self.sql, - self.parameters, - self._hook, + def __init__(self, *args, **kwargs): + warnings.warn( + """This class is deprecated. + Please use `airflow.operators.sql.BranchSQLOperator`.""", + DeprecationWarning, stacklevel=2 ) - record = self._hook.get_first(self.sql, self.parameters) - if not record: - raise AirflowException( - "No rows returned from sql query. Operator expected True or False return value." - ) - - if isinstance(record, list): - if isinstance(record[0], list): - query_result = record[0][0] - else: - query_result = record[0] - elif isinstance(record, tuple): - query_result = record[0] - else: - query_result = record - - self.log.info("Query returns %s, type '%s'", query_result, type(query_result)) - - follow_branch = None - try: - if isinstance(query_result, bool): - if query_result: - follow_branch = self.follow_task_ids_if_true - elif isinstance(query_result, str): - # return result is not Boolean, try to convert from String to Boolean - if bool(strtobool(query_result)): - follow_branch = self.follow_task_ids_if_true - elif isinstance(query_result, int): - if bool(query_result): - follow_branch = self.follow_task_ids_if_true - else: - raise AirflowException( - "Unexpected query return result '%s' type '%s'" - % (query_result, type(query_result)) - ) - - if follow_branch is None: - follow_branch = self.follow_task_ids_if_false - except ValueError: - raise AirflowException( - "Unexpected query return result '%s' type '%s'" - % (query_result, type(query_result)) - ) - - self.skip_all_except(context["ti"], follow_branch) + super(BranchSqlOperator, self).__init__(*args, **kwargs) diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst index 55176f8..1fd11c3 100644 --- a/docs/operators-and-hooks-ref.rst +++ b/docs/operators-and-hooks-ref.rst @@ -57,10 +57,6 @@ Fundamentals * - :mod:`airflow.operators.branch_operator` - - - * - :mod:`airflow.operators.check_operator` - - - * - :mod:`airflow.operators.dagrun_operator` - @@ -76,7 +72,7 @@ Fundamentals * - :mod:`airflow.operators.subdag_operator` - - * - :mod:`airflow.operators.sql_branch_operator` + * - :mod:`airflow.operators.sql` - **Sensors:** @@ -90,9 +86,6 @@ Fundamentals * - :mod:`airflow.sensors.weekday_sensor` - - * - :mod:`airflow.sensors.external_task_sensor` - - :doc:`How to use <howto/operator/external_task_sensor>` - * - :mod:`airflow.sensors.sql_sensor` - @@ -470,7 +463,7 @@ These integrations allow you to copy data from/to Amazon Web Services. * - `Amazon Simple Storage Service (S3) <https://aws.amazon.com/s3/>`__ - `Google Cloud Storage (GCS) <https://cloud.google.com/gcs/>`__ - - :doc:`How to use <howto/operator/gcp/cloud_storage_transfer_service>` + - - :mod:`airflow.contrib.operators.s3_to_gcs_operator`, :mod:`airflow.gcp.operators.cloud_storage_transfer_service` @@ -551,7 +544,7 @@ These integrations allow you to perform various operations within the Google Clo - Sensors * - `AutoML <https://cloud.google.com/automl/>`__ - - :doc:`How to use <howto/operator/gcp/automl>` + - - :mod:`airflow.gcp.hooks.automl` - :mod:`airflow.gcp.operators.automl` - @@ -563,7 +556,7 @@ These integrations allow you to perform various operations within the Google Clo - :mod:`airflow.gcp.sensors.bigquery` * - `BigQuery Data Transfer Service <https://cloud.google.com/bigquery/transfer/>`__ - - :doc:`How to use <howto/operator/gcp/bigquery_dts>` + - - :mod:`airflow.gcp.hooks.bigquery_dts` - :mod:`airflow.gcp.operators.bigquery_dts` - :mod:`airflow.gcp.sensors.bigquery_dts` @@ -611,7 +604,7 @@ These integrations allow you to perform various operations within the Google Clo - * - `Cloud Functions <https://cloud.google.com/functions/>`__ - - :doc:`How to use <howto/operator/gcp/functions>` + - :doc:`How to use <howto/operator/gcp/function>` - :mod:`airflow.gcp.hooks.functions` - :mod:`airflow.gcp.operators.functions` - @@ -635,7 +628,7 @@ These integrations allow you to perform various operations within the Google Clo - * - `Cloud Memorystore <https://cloud.google.com/memorystore/>`__ - - :doc:`How to use <howto/operator/gcp/cloud_memorystore>` + - - :mod:`airflow.gcp.hooks.cloud_memorystore` - :mod:`airflow.gcp.operators.cloud_memorystore` - @@ -677,7 +670,7 @@ These integrations allow you to perform various operations within the Google Clo - :mod:`airflow.gcp.sensors.gcs` * - `Storage Transfer Service <https://cloud.google.com/storage/transfer/>`__ - - :doc:`How to use <howto/operator/gcp/cloud_storage_transfer_service>` + - - :mod:`airflow.gcp.hooks.cloud_storage_transfer_service` - :mod:`airflow.gcp.operators.cloud_storage_transfer_service` - :mod:`airflow.gcp.sensors.cloud_storage_transfer_service` @@ -701,7 +694,7 @@ These integrations allow you to perform various operations within the Google Clo - * - `Cloud Video Intelligence <https://cloud.google.com/video_intelligence/>`__ - - :doc:`How to use <howto/operator/gcp/video_intelligence>` + - :doc:`How to use <howto/operator/gcp/video>` - :mod:`airflow.gcp.hooks.video_intelligence` - :mod:`airflow.gcp.operators.video_intelligence` - @@ -741,7 +734,7 @@ These integrations allow you to copy data from/to Google Cloud Platform. * - `Amazon Simple Storage Service (S3) <https://aws.amazon.com/s3/>`__ - `Google Cloud Storage (GCS) <https://cloud.google.com/gcs/>`__ - - :doc:`How to use <howto/operator/gcp/cloud_storage_transfer_service>` + - - :mod:`airflow.contrib.operators.s3_to_gcs_operator`, :mod:`airflow.gcp.operators.cloud_storage_transfer_service` @@ -772,8 +765,7 @@ These integrations allow you to copy data from/to Google Cloud Platform. * - `Google Cloud Storage (GCS) <https://cloud.google.com/gcs/>`__ - `Google Cloud Storage (GCS) <https://cloud.google.com/gcs/>`__ - - :doc:`How to use <howto/operator/gcp/gcs_to_gcs>`, - :doc:`How to use <howto/operator/gcp/cloud_storage_transfer_service>` + - - :mod:`airflow.operators.gcs_to_gcs`, :mod:`airflow.gcp.operators.cloud_storage_transfer_service` @@ -1037,7 +1029,7 @@ These integrations allow you to perform various operations using various softwar - :mod:`airflow.contrib.sensors.bash_sensor` * - `Kubernetes <https://kubernetes.io/>`__ - - :doc:`How to use <howto/operator/kubernetes>` + - - - :mod:`airflow.contrib.operators.kubernetes_pod_operator` - diff --git a/tests/api/common/experimental/test_pool.py b/tests/api/common/experimental/test_pool.py index 29c7105..97c970b 100644 --- a/tests/api/common/experimental/test_pool.py +++ b/tests/api/common/experimental/test_pool.py @@ -56,17 +56,18 @@ class TestPool(unittest.TestCase): self.assertEqual(pool.pool, self.pools[0].pool) def test_get_pool_non_existing(self): - self.assertRaisesRegexp(PoolNotFound, - "^Pool 'test' doesn't exist$", - pool_api.get_pool, - name='test') + six.assertRaisesRegex(self, PoolNotFound, + "^Pool 'test' doesn't exist$", + pool_api.get_pool, + name='test') def test_get_pool_bad_name(self): for name in ('', ' '): - self.assertRaisesRegexp(AirflowBadRequest, - "^Pool name shouldn't be empty$", - pool_api.get_pool, - name=name) + six.assertRaisesRegex(self, + AirflowBadRequest, + "^Pool name shouldn't be empty$", + pool_api.get_pool, + name=name) def test_get_pools(self): pools = sorted(pool_api.get_pools(), @@ -96,20 +97,21 @@ class TestPool(unittest.TestCase): def test_create_pool_bad_name(self): for name in ('', ' '): - self.assertRaisesRegexp(AirflowBadRequest, - "^Pool name shouldn't be empty$", - pool_api.create_pool, - name=name, - slots=5, - description='') + six.assertRaisesRegex(self, + AirflowBadRequest, + "^Pool name shouldn't be empty$", + pool_api.create_pool, + name=name, + slots=5, + description='') def test_create_pool_bad_slots(self): - self.assertRaisesRegexp(AirflowBadRequest, - "^Bad value for `slots`: foo$", - pool_api.create_pool, - name='foo', - slots='foo', - description='') + six.assertRaisesRegex(self, AirflowBadRequest, + "^Bad value for `slots`: foo$", + pool_api.create_pool, + name='foo', + slots='foo', + description='') def test_delete_pool(self): pool = pool_api.delete_pool(name=self.pools[-1].pool) @@ -118,21 +120,23 @@ class TestPool(unittest.TestCase): self.assertEqual(session.query(models.Pool).count(), self.TOTAL_POOL_COUNT - 1) def test_delete_pool_non_existing(self): - self.assertRaisesRegexp(pool_api.PoolNotFound, - "^Pool 'test' doesn't exist$", - pool_api.delete_pool, - name='test') + six.assertRaisesRegex(self, pool_api.PoolNotFound, + "^Pool 'test' doesn't exist$", + pool_api.delete_pool, + name='test') def test_delete_pool_bad_name(self): for name in ('', ' '): - self.assertRaisesRegexp(AirflowBadRequest, - "^Pool name shouldn't be empty$", - pool_api.delete_pool, - name=name) + six.assertRaisesRegex(self, + AirflowBadRequest, + "^Pool name shouldn't be empty$", + pool_api.delete_pool, + name=name) def test_delete_default_pool_not_allowed(self): - with self.assertRaisesRegex(AirflowBadRequest, - "^default_pool cannot be deleted$"): + with six.assertRaisesRegex(self, + AirflowBadRequest, + "^default_pool cannot be deleted$"): pool_api.delete_pool(Pool.DEFAULT_POOL_NAME) diff --git a/tests/contrib/hooks/test_discord_webhook_hook.py b/tests/contrib/hooks/test_discord_webhook_hook.py index d0c9001..384b7f3 100644 --- a/tests/contrib/hooks/test_discord_webhook_hook.py +++ b/tests/contrib/hooks/test_discord_webhook_hook.py @@ -20,6 +20,8 @@ import json import unittest +import six + from airflow import AirflowException from airflow.models import Connection from airflow.utils import db @@ -73,7 +75,7 @@ class TestDiscordWebhookHook(unittest.TestCase): # When/Then expected_message = 'Expected Discord webhook endpoint in the form of' - with self.assertRaisesRegexp(AirflowException, expected_message): + with six.assertRaisesRegex(self, AirflowException, expected_message): DiscordWebhookHook(webhook_endpoint=provided_endpoint) def test_get_webhook_endpoint_conn_id(self): @@ -107,7 +109,7 @@ class TestDiscordWebhookHook(unittest.TestCase): # When/Then expected_message = 'Discord message length must be 2000 or fewer characters' - with self.assertRaisesRegexp(AirflowException, expected_message): + with six.assertRaisesRegex(self, AirflowException, expected_message): hook._build_discord_payload() diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py index 9a7b6ec..6e59408 100644 --- a/tests/contrib/operators/test_databricks_operator.py +++ b/tests/contrib/operators/test_databricks_operator.py @@ -21,6 +21,8 @@ import unittest from datetime import datetime +import six + from airflow.contrib.hooks.databricks_hook import RunState import airflow.contrib.operators.databricks_operator as databricks_operator from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator @@ -180,7 +182,7 @@ class DatabricksSubmitRunOperatorTest(unittest.TestCase): # Looks a bit weird since we have to escape regex reserved symbols. exception_message = r'Type \<(type|class) \'datetime.datetime\'\> used ' + \ r'for parameter json\[test\] is not a number or a string' - with self.assertRaisesRegexp(AirflowException, exception_message): + with six.assertRaisesRegex(self, AirflowException, exception_message): DatabricksSubmitRunOperator(task_id=TASK_ID, json=json) @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook') @@ -347,7 +349,7 @@ class DatabricksRunNowOperatorTest(unittest.TestCase): # Looks a bit weird since we have to escape regex reserved symbols. exception_message = r'Type \<(type|class) \'datetime.datetime\'\> used ' + \ r'for parameter json\[test\] is not a number or a string' - with self.assertRaisesRegexp(AirflowException, exception_message): + with six.assertRaisesRegex(self, AirflowException, exception_message): DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json) @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook') diff --git a/tests/contrib/operators/test_gcs_to_gcs_operator.py b/tests/contrib/operators/test_gcs_to_gcs_operator.py index f9085e2..622fa8f 100644 --- a/tests/contrib/operators/test_gcs_to_gcs_operator.py +++ b/tests/contrib/operators/test_gcs_to_gcs_operator.py @@ -20,6 +20,8 @@ import unittest from datetime import datetime +import six + from airflow.contrib.operators.gcs_to_gcs import \ GoogleCloudStorageToGoogleCloudStorageOperator, WILDCARD from airflow.exceptions import AirflowException @@ -290,7 +292,7 @@ class GoogleCloudStorageToCloudStorageOperatorTest(unittest.TestCase): error_msg = "Only one wildcard '[*]' is allowed in source_object parameter. " \ "Found {}".format(total_wildcards) - with self.assertRaisesRegexp(AirflowException, error_msg): + with six.assertRaisesRegex(self, AirflowException, error_msg): operator.execute(None) @mock.patch('airflow.contrib.operators.gcs_to_gcs.GoogleCloudStorageHook') diff --git a/tests/contrib/operators/test_qubole_check_operator.py b/tests/contrib/operators/test_qubole_check_operator.py index b1692d8..f6d875a 100644 --- a/tests/contrib/operators/test_qubole_check_operator.py +++ b/tests/contrib/operators/test_qubole_check_operator.py @@ -19,6 +19,9 @@ # import unittest from datetime import datetime + +import six + from airflow.models import DAG from airflow.exceptions import AirflowException from airflow.contrib.operators.qubole_check_operator import QuboleValueCheckOperator @@ -88,8 +91,8 @@ class QuboleValueCheckOperatorTest(unittest.TestCase): operator = self.__construct_operator('select value from tab1 limit 1;', 5, 1) - with self.assertRaisesRegexp(AirflowException, - 'Qubole Command Id: ' + str(mock_cmd.id)): + with six.assertRaisesRegex(self, AirflowException, + 'Qubole Command Id: ' + str(mock_cmd.id)): operator.execute() mock_cmd.is_success.assert_called_with(mock_cmd.status) diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py index 24db36e..fe478b9 100644 --- a/tests/contrib/operators/test_sftp_operator.py +++ b/tests/contrib/operators/test_sftp_operator.py @@ -362,10 +362,8 @@ class SFTPOperatorTest(unittest.TestCase): os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost" # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided - if six.PY2: - self.assertRaisesRegex = self.assertRaisesRegexp - with self.assertRaisesRegex(AirflowException, - "Cannot operate without ssh_hook or ssh_conn_id."): + with six.assertRaisesRegex(self, AirflowException, + "Cannot operate without ssh_hook or ssh_conn_id."): task_0 = SFTPOperator( task_id="test_sftp", local_filepath=self.test_local_filepath, diff --git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py index f2294ba..1413050 100644 --- a/tests/contrib/operators/test_ssh_operator.py +++ b/tests/contrib/operators/test_ssh_operator.py @@ -152,10 +152,8 @@ class SSHOperatorTest(TestCase): os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost" # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided - if six.PY2: - self.assertRaisesRegex = self.assertRaisesRegexp - with self.assertRaisesRegex(AirflowException, - "Cannot operate without ssh_hook or ssh_conn_id."): + with six.assertRaisesRegex(self, AirflowException, + "Cannot operate without ssh_hook or ssh_conn_id."): task_0 = SSHOperator(task_id="test", command=COMMAND, timeout=TIMEOUT, dag=self.dag) task_0.execute(None) diff --git a/tests/contrib/operators/test_winrm_operator.py b/tests/contrib/operators/test_winrm_operator.py index 27792a0..c6b26f7 100644 --- a/tests/contrib/operators/test_winrm_operator.py +++ b/tests/contrib/operators/test_winrm_operator.py @@ -20,6 +20,8 @@ import mock import unittest +import six + from airflow.contrib.operators.winrm_operator import WinRMOperator from airflow.exceptions import AirflowException @@ -30,7 +32,7 @@ class WinRMOperatorTest(unittest.TestCase): winrm_hook=None, ssh_conn_id=None) exception_msg = "Cannot operate without winrm_hook or ssh_conn_id." - with self.assertRaisesRegexp(AirflowException, exception_msg): + with six.assertRaisesRegex(self, AirflowException, exception_msg): op.execute(None) @mock.patch('airflow.contrib.operators.winrm_operator.WinRMHook') @@ -41,7 +43,7 @@ class WinRMOperatorTest(unittest.TestCase): command=None ) exception_msg = "No command specified so nothing to execute here." - with self.assertRaisesRegexp(AirflowException, exception_msg): + with six.assertRaisesRegex(self, AirflowException, exception_msg): op.execute(None) diff --git a/tests/contrib/sensors/test_weekday_sensor.py b/tests/contrib/sensors/test_weekday_sensor.py index 016f71c..d822a69 100644 --- a/tests/contrib/sensors/test_weekday_sensor.py +++ b/tests/contrib/sensors/test_weekday_sensor.py @@ -19,6 +19,9 @@ # import unittest + +import six + from airflow import DAG, models from airflow.contrib.sensors.weekday_sensor import DayOfWeekSensor from airflow.contrib.utils.weekday import WeekDay @@ -78,9 +81,8 @@ class DayOfWeekSensorTests(unittest.TestCase): def test_invalid_weekday_number(self): invalid_week_day = 'Thsday' - with self.assertRaisesRegexp(AttributeError, - 'Invalid Week Day passed: "{}"'.format( - invalid_week_day)): + with six.assertRaisesRegex(self, AttributeError, + 'Invalid Week Day passed: "{}"'.format(invalid_week_day)): DayOfWeekSensor( task_id='weekday_sensor_invalid_weekday_num', week_day=invalid_week_day, @@ -139,11 +141,11 @@ class DayOfWeekSensorTests(unittest.TestCase): def test_weekday_sensor_with_invalid_type(self): invalid_week_day = ['Thsday'] - with self.assertRaisesRegexp(TypeError, - 'Unsupported Type for week_day parameter:' - ' {}. It should be one of str, set or ' - 'Weekday enum type'.format(type(invalid_week_day)) - ): + with six.assertRaisesRegex(self, TypeError, + 'Unsupported Type for week_day parameter:' + ' {}. It should be one of str, set or ' + 'Weekday enum type'.format(type(invalid_week_day)) + ): DayOfWeekSensor( task_id='weekday_sensor_check_true', week_day=invalid_week_day, diff --git a/tests/contrib/utils/test_mlengine_operator_utils.py b/tests/contrib/utils/test_mlengine_operator_utils.py index 28efef5..53e1323 100644 --- a/tests/contrib/utils/test_mlengine_operator_utils.py +++ b/tests/contrib/utils/test_mlengine_operator_utils.py @@ -22,6 +22,8 @@ from __future__ import print_function import datetime import unittest +import six + from airflow import DAG from airflow.contrib.utils import mlengine_operator_utils from airflow.exceptions import AirflowException @@ -152,25 +154,25 @@ class CreateEvaluateOpsTest(unittest.TestCase): 'dag': dag, } - with self.assertRaisesRegexp(AirflowException, 'Missing model origin'): + with six.assertRaisesRegex(self, AirflowException, 'Missing model origin'): mlengine_operator_utils.create_evaluate_ops(**other_params_but_models) - with self.assertRaisesRegexp(AirflowException, 'Ambiguous model origin'): + with six.assertRaisesRegex(self, AirflowException, 'Ambiguous model origin'): mlengine_operator_utils.create_evaluate_ops(model_uri='abc', model_name='cde', **other_params_but_models) - with self.assertRaisesRegexp(AirflowException, 'Ambiguous model origin'): + with six.assertRaisesRegex(self, AirflowException, 'Ambiguous model origin'): mlengine_operator_utils.create_evaluate_ops(model_uri='abc', version_name='vvv', **other_params_but_models) - with self.assertRaisesRegexp(AirflowException, - '`metric_fn` param must be callable'): + with six.assertRaisesRegex(self, AirflowException, + '`metric_fn` param must be callable'): params = other_params_but_models.copy() params['metric_fn_and_keys'] = (None, ['abc']) mlengine_operator_utils.create_evaluate_ops(model_uri='gs://blah', **params) - with self.assertRaisesRegexp(AirflowException, - '`validate_fn` param must be callable'): + with six.assertRaisesRegex(self, AirflowException, + '`validate_fn` param must be callable'): params = other_params_but_models.copy() params['validate_fn'] = None mlengine_operator_utils.create_evaluate_ops(model_uri='gs://blah', **params) diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index e522556..d272fcf 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -25,6 +25,7 @@ import threading import unittest import pytest +import six import sqlalchemy from parameterized import parameterized @@ -790,7 +791,8 @@ class BackfillJobTest(unittest.TestCase): run_date = DEFAULT_DATE + datetime.timedelta(days=5) # backfill should deadlock - self.assertRaisesRegexp( + six.assertRaisesRegex( + self, AirflowException, 'BackfillJob is deadlocked', BackfillJob(dag=dag, start_date=run_date, end_date=run_date).run) @@ -890,7 +892,7 @@ class BackfillJobTest(unittest.TestCase): # raises backwards expected_msg = 'You cannot backfill backwards because one or more tasks depend_on_past: {}'.format( 'test_dop_task') - with self.assertRaisesRegexp(AirflowException, expected_msg): + with six.assertRaisesRegex(self, AirflowException, expected_msg): executor = MockExecutor() job = BackfillJob(dag=dag, executor=executor, @@ -1166,7 +1168,8 @@ class BackfillJobTest(unittest.TestCase): start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor) - self.assertRaisesRegexp( + six.assertRaisesRegex( + self, AirflowException, 'Some task instances failed', job.run) diff --git a/tests/kubernetes/test_worker_configuration.py b/tests/kubernetes/test_worker_configuration.py index 1b15d98..8378f9f 100644 --- a/tests/kubernetes/test_worker_configuration.py +++ b/tests/kubernetes/test_worker_configuration.py @@ -19,6 +19,9 @@ import unittest import uuid from datetime import datetime + +import six + from tests.compat import mock from tests.test_utils.config import conf_vars try: @@ -99,10 +102,10 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase): ('kubernetes', 'kube_client_request_args'): '{"_request_timeout" : [60,360]}', }) def test_worker_configuration_auth_both_ssh_and_user(self): - with self.assertRaisesRegexp(AirflowConfigException, - 'either `git_user` and `git_password`.*' - 'or `git_ssh_key_secret_name`.*' - 'but not both$'): + with six.assertRaisesRegex(self, AirflowConfigException, + 'either `git_user` and `git_password`.*' + 'or `git_ssh_key_secret_name`.*' + 'but not both$'): KubeConfig() def test_worker_with_subpaths(self): diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 2d00c59..ea65823 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -19,6 +19,9 @@ import datetime import unittest + +import six + from tests.compat import mock import uuid @@ -196,7 +199,7 @@ class BaseOperatorTest(unittest.TestCase): re = "('ClassWithCustomAttributes' object|ClassWithCustomAttributes instance) " \ "has no attribute 'missing_field'" - with self.assertRaisesRegexp(AttributeError, re): + with six.assertRaisesRegex(self, AttributeError, re): task.render_template(ClassWithCustomAttributes(template_fields=["missing_field"]), {}) def test_jinja_invalid_expression_is_just_propagated(self): diff --git a/tests/operators/test_check_operator.py b/tests/operators/test_check_operator.py deleted file mode 100644 index 22523a4..0000000 --- a/tests/operators/test_check_operator.py +++ /dev/null @@ -1,327 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import six -import unittest -from datetime import datetime - -from airflow.exceptions import AirflowException -from airflow.models import DAG -from airflow.operators.check_operator import ( - CheckOperator, IntervalCheckOperator, ThresholdCheckOperator, ValueCheckOperator, -) -from tests.compat import mock - - -class TestCheckOperator(unittest.TestCase): - - @mock.patch.object(CheckOperator, 'get_db_hook') - def test_execute_no_records(self, mock_get_db_hook): - mock_get_db_hook.return_value.get_first.return_value = [] - - with self.assertRaises(AirflowException): - CheckOperator(sql='sql').execute() - - @mock.patch.object(CheckOperator, 'get_db_hook') - def test_execute_not_all_records_are_true(self, mock_get_db_hook): - mock_get_db_hook.return_value.get_first.return_value = ["data", ""] - - with self.assertRaises(AirflowException): - CheckOperator(sql='sql').execute() - - -class TestValueCheckOperator(unittest.TestCase): - - def setUp(self): - self.task_id = 'test_task' - self.conn_id = 'default_conn' - - def _construct_operator(self, sql, pass_value, tolerance=None): - dag = DAG('test_dag', start_date=datetime(2017, 1, 1)) - - return ValueCheckOperator( - dag=dag, - task_id=self.task_id, - conn_id=self.conn_id, - sql=sql, - pass_value=pass_value, - tolerance=tolerance) - - def test_pass_value_template_string(self): - pass_value_str = "2018-03-22" - operator = self._construct_operator('select date from tab1;', "{{ ds }}") - - operator.render_template_fields({'ds': pass_value_str}) - - self.assertEqual(operator.task_id, self.task_id) - self.assertEqual(operator.pass_value, pass_value_str) - - def test_pass_value_template_string_float(self): - pass_value_float = 4.0 - operator = self._construct_operator('select date from tab1;', pass_value_float) - - operator.render_template_fields({}) - - self.assertEqual(operator.task_id, self.task_id) - self.assertEqual(operator.pass_value, str(pass_value_float)) - - @mock.patch.object(ValueCheckOperator, 'get_db_hook') - def test_execute_pass(self, mock_get_db_hook): - mock_hook = mock.Mock() - mock_hook.get_first.return_value = [10] - mock_get_db_hook.return_value = mock_hook - sql = 'select value from tab1 limit 1;' - operator = self._construct_operator(sql, 5, 1) - - operator.execute(None) - - mock_hook.get_first.assert_called_with(sql) - - @mock.patch.object(ValueCheckOperator, 'get_db_hook') - def test_execute_fail(self, mock_get_db_hook): - mock_hook = mock.Mock() - mock_hook.get_first.return_value = [11] - mock_get_db_hook.return_value = mock_hook - - operator = self._construct_operator('select value from tab1 limit 1;', 5, 1) - - with self.assertRaisesRegexp(AirflowException, 'Tolerance:100.0%'): - operator.execute() - - -class IntervalCheckOperatorTest(unittest.TestCase): - - def _construct_operator(self, table, metric_thresholds, - ratio_formula, ignore_zero): - return IntervalCheckOperator( - task_id='test_task', - table=table, - metrics_thresholds=metric_thresholds, - ratio_formula=ratio_formula, - ignore_zero=ignore_zero, - ) - - def test_invalid_ratio_formula(self): - with self.assertRaisesRegexp(AirflowException, 'Invalid diff_method'): - self._construct_operator( - table='test_table', - metric_thresholds={ - 'f1': 1, - }, - ratio_formula='abs', - ignore_zero=False, - ) - - @mock.patch.object(IntervalCheckOperator, 'get_db_hook') - def test_execute_not_ignore_zero(self, mock_get_db_hook): - mock_hook = mock.Mock() - mock_hook.get_first.return_value = [0] - mock_get_db_hook.return_value = mock_hook - - operator = self._construct_operator( - table='test_table', - metric_thresholds={ - 'f1': 1, - }, - ratio_formula='max_over_min', - ignore_zero=False, - ) - - with self.assertRaises(AirflowException): - operator.execute() - - @mock.patch.object(IntervalCheckOperator, 'get_db_hook') - def test_execute_ignore_zero(self, mock_get_db_hook): - mock_hook = mock.Mock() - mock_hook.get_first.return_value = [0] - mock_get_db_hook.return_value = mock_hook - - operator = self._construct_operator( - table='test_table', - metric_thresholds={ - 'f1': 1, - }, - ratio_formula='max_over_min', - ignore_zero=True, - ) - - operator.execute() - - @mock.patch.object(IntervalCheckOperator, 'get_db_hook') - def test_execute_min_max(self, mock_get_db_hook): - mock_hook = mock.Mock() - - def returned_row(): - rows = [ - [2, 2, 2, 2], # reference - [1, 1, 1, 1], # current - ] - - for r in rows: - yield r - - mock_hook.get_first.side_effect = returned_row() - mock_get_db_hook.return_value = mock_hook - - operator = self._construct_operator( - table='test_table', - metric_thresholds={ - 'f0': 1.0, - 'f1': 1.5, - 'f2': 2.0, - 'f3': 2.5, - }, - ratio_formula='max_over_min', - ignore_zero=True, - ) - - with self.assertRaisesRegexp(AirflowException, "f0, f1, f2"): - operator.execute() - - @mock.patch.object(IntervalCheckOperator, 'get_db_hook') - def test_execute_diff(self, mock_get_db_hook): - mock_hook = mock.Mock() - - def returned_row(): - rows = [ - [3, 3, 3, 3], # reference - [1, 1, 1, 1], # current - ] - - for r in rows: - yield r - - mock_hook.get_first.side_effect = returned_row() - mock_get_db_hook.return_value = mock_hook - - operator = self._construct_operator( - table='test_table', - metric_thresholds={ - 'f0': 0.5, - 'f1': 0.6, - 'f2': 0.7, - 'f3': 0.8, - }, - ratio_formula='relative_diff', - ignore_zero=True, - ) - - with self.assertRaisesRegexp(AirflowException, "f0, f1"): - operator.execute() - - -class TestThresholdCheckOperator(unittest.TestCase): - - def _construct_operator(self, sql, min_threshold, max_threshold): - dag = DAG('test_dag', start_date=datetime(2017, 1, 1)) - - return ThresholdCheckOperator( - task_id='test_task', - sql=sql, - min_threshold=min_threshold, - max_threshold=max_threshold, - dag=dag - ) - - @mock.patch.object(ThresholdCheckOperator, 'get_db_hook') - def test_pass_min_value_max_value(self, mock_get_db_hook): - mock_hook = mock.Mock() - mock_hook.get_first.return_value = [(10,)] - mock_get_db_hook.return_value = mock_hook - - operator = self._construct_operator( - 'Select avg(val) from table1 limit 1', - 1, - 100 - ) - - operator.execute() - - @mock.patch.object(ThresholdCheckOperator, 'get_db_hook') - def test_fail_min_value_max_value(self, mock_get_db_hook): - mock_hook = mock.Mock() - mock_hook.get_first.return_value = [(10,)] - mock_get_db_hook.return_value = mock_hook - - operator = self._construct_operator( - 'Select avg(val) from table1 limit 1', - 20, - 100 - ) - - with six.assertRaisesRegex(self, AirflowException, '10.*20.0.*100.0'): - operator.execute() - - @mock.patch.object(ThresholdCheckOperator, 'get_db_hook') - def test_pass_min_sql_max_sql(self, mock_get_db_hook): - mock_hook = mock.Mock() - mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)] - mock_get_db_hook.return_value = mock_hook - - operator = self._construct_operator( - 'Select 10', - 'Select 1', - 'Select 100' - ) - - operator.execute() - - @mock.patch.object(ThresholdCheckOperator, 'get_db_hook') - def test_fail_min_sql_max_sql(self, mock_get_db_hook): - mock_hook = mock.Mock() - mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)] - mock_get_db_hook.return_value = mock_hook - - operator = self._construct_operator( - 'Select 10', - 'Select 20', - 'Select 100' - ) - - with six.assertRaisesRegex(self, AirflowException, '10.*20.*100'): - operator.execute() - - @mock.patch.object(ThresholdCheckOperator, 'get_db_hook') - def test_pass_min_value_max_sql(self, mock_get_db_hook): - mock_hook = mock.Mock() - mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)] - mock_get_db_hook.return_value = mock_hook - - operator = self._construct_operator( - 'Select 75', - 45, - 'Select 100' - ) - - operator.execute() - - @mock.patch.object(ThresholdCheckOperator, 'get_db_hook') - def test_fail_min_sql_max_value(self, mock_get_db_hook): - mock_hook = mock.Mock() - mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)] - mock_get_db_hook.return_value = mock_hook - - operator = self._construct_operator( - 'Select 155', - 'Select 45', - 100 - ) - - with six.assertRaisesRegex(self, AirflowException, '155.*45.*100.0'): - operator.execute() diff --git a/tests/operators/test_s3_to_hive_operator.py b/tests/operators/test_s3_to_hive_operator.py index 8366465..8ed7ad2 100644 --- a/tests/operators/test_s3_to_hive_operator.py +++ b/tests/operators/test_s3_to_hive_operator.py @@ -19,12 +19,14 @@ import unittest +import six + +from airflow import AirflowException from tests.compat import mock import logging from itertools import product from airflow.operators.s3_to_hive_operator import S3ToHiveTransfer from collections import OrderedDict -from airflow.exceptions import AirflowException from tempfile import NamedTemporaryFile, mkdtemp from gzip import GzipFile import bz2 @@ -156,10 +158,10 @@ class S3ToHiveTransferTest(unittest.TestCase): def test_bad_parameters(self): self.kwargs['check_headers'] = True self.kwargs['headers'] = False - self.assertRaisesRegexp(AirflowException, - "To check_headers.*", - S3ToHiveTransfer, - **self.kwargs) + six.assertRaisesRegex(self, AirflowException, + "To check_headers.*", + S3ToHiveTransfer, + **self.kwargs) def test__get_top_row_as_list(self): self.kwargs['delimiter'] = '\t' diff --git a/tests/operators/test_sql_branch_operator.py b/tests/operators/test_sql.py similarity index 57% rename from tests/operators/test_sql_branch_operator.py rename to tests/operators/test_sql.py index 6510609..6ccc5fa 100644 --- a/tests/operators/test_sql_branch_operator.py +++ b/tests/operators/test_sql.py @@ -18,14 +18,20 @@ import datetime import unittest + +import six + from tests.compat import mock import pytest from airflow.exceptions import AirflowException from airflow.models import DAG, DagRun, TaskInstance as TI +from airflow.operators.check_operator import ( + CheckOperator, IntervalCheckOperator, ThresholdCheckOperator, ValueCheckOperator, +) from airflow.operators.dummy_operator import DummyOperator -from airflow.operators.sql_branch_operator import BranchSqlOperator +from airflow.operators.sql import BranchSQLOperator from airflow.utils import timezone from airflow.utils.db import create_session from airflow.utils.state import State @@ -60,6 +66,266 @@ SUPPORTED_FALSE_VALUES = [ ] +class TestCheckOperator(unittest.TestCase): + @mock.patch.object(CheckOperator, "get_db_hook") + def test_execute_no_records(self, mock_get_db_hook): + mock_get_db_hook.return_value.get_first.return_value = [] + + with self.assertRaises(AirflowException): + CheckOperator(sql="sql").execute() + + @mock.patch.object(CheckOperator, "get_db_hook") + def test_execute_not_all_records_are_true(self, mock_get_db_hook): + mock_get_db_hook.return_value.get_first.return_value = ["data", ""] + + with self.assertRaises(AirflowException): + CheckOperator(sql="sql").execute() + + +class TestValueCheckOperator(unittest.TestCase): + def setUp(self): + self.task_id = "test_task" + self.conn_id = "default_conn" + + def _construct_operator(self, sql, pass_value, tolerance=None): + dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1)) + + return ValueCheckOperator( + dag=dag, + task_id=self.task_id, + conn_id=self.conn_id, + sql=sql, + pass_value=pass_value, + tolerance=tolerance, + ) + + def test_pass_value_template_string(self): + pass_value_str = "2018-03-22" + operator = self._construct_operator( + "select date from tab1;", "{{ ds }}") + + operator.render_template_fields({"ds": pass_value_str}) + + self.assertEqual(operator.task_id, self.task_id) + self.assertEqual(operator.pass_value, pass_value_str) + + def test_pass_value_template_string_float(self): + pass_value_float = 4.0 + operator = self._construct_operator( + "select date from tab1;", pass_value_float) + + operator.render_template_fields({}) + + self.assertEqual(operator.task_id, self.task_id) + self.assertEqual(operator.pass_value, str(pass_value_float)) + + @mock.patch.object(ValueCheckOperator, "get_db_hook") + def test_execute_pass(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.return_value = [10] + mock_get_db_hook.return_value = mock_hook + sql = "select value from tab1 limit 1;" + operator = self._construct_operator(sql, 5, 1) + + operator.execute(None) + + mock_hook.get_first.assert_called_once_with(sql) + + @mock.patch.object(ValueCheckOperator, "get_db_hook") + def test_execute_fail(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.return_value = [11] + mock_get_db_hook.return_value = mock_hook + + operator = self._construct_operator( + "select value from tab1 limit 1;", 5, 1) + + with six.assertRaisesRegex(self, AirflowException, "Tolerance:100.0%"): + operator.execute() + + +class TestIntervalCheckOperator(unittest.TestCase): + def _construct_operator(self, table, metric_thresholds, ratio_formula, ignore_zero): + return IntervalCheckOperator( + task_id="test_task", + table=table, + metrics_thresholds=metric_thresholds, + ratio_formula=ratio_formula, + ignore_zero=ignore_zero, + ) + + def test_invalid_ratio_formula(self): + with six.assertRaisesRegex(self, AirflowException, "Invalid diff_method"): + self._construct_operator( + table="test_table", + metric_thresholds={"f1": 1, }, + ratio_formula="abs", + ignore_zero=False, + ) + + @mock.patch.object(IntervalCheckOperator, "get_db_hook") + def test_execute_not_ignore_zero(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.return_value = [0] + mock_get_db_hook.return_value = mock_hook + + operator = self._construct_operator( + table="test_table", + metric_thresholds={"f1": 1, }, + ratio_formula="max_over_min", + ignore_zero=False, + ) + + with self.assertRaises(AirflowException): + operator.execute() + + @mock.patch.object(IntervalCheckOperator, "get_db_hook") + def test_execute_ignore_zero(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.return_value = [0] + mock_get_db_hook.return_value = mock_hook + + operator = self._construct_operator( + table="test_table", + metric_thresholds={"f1": 1, }, + ratio_formula="max_over_min", + ignore_zero=True, + ) + + operator.execute() + + @mock.patch.object(IntervalCheckOperator, "get_db_hook") + def test_execute_min_max(self, mock_get_db_hook): + mock_hook = mock.Mock() + + def returned_row(): + rows = [ + [2, 2, 2, 2], # reference + [1, 1, 1, 1], # current + ] + return rows + + mock_hook.get_first.side_effect = returned_row() + mock_get_db_hook.return_value = mock_hook + + operator = self._construct_operator( + table="test_table", + metric_thresholds={"f0": 1.0, "f1": 1.5, "f2": 2.0, "f3": 2.5, }, + ratio_formula="max_over_min", + ignore_zero=True, + ) + + with six.assertRaisesRegex(self, AirflowException, "f0, f1, f2"): + operator.execute() + + @mock.patch.object(IntervalCheckOperator, "get_db_hook") + def test_execute_diff(self, mock_get_db_hook): + mock_hook = mock.Mock() + + def returned_row(): + rows = [ + [3, 3, 3, 3], # reference + [1, 1, 1, 1], # current + ] + + return rows + + mock_hook.get_first.side_effect = returned_row() + mock_get_db_hook.return_value = mock_hook + + operator = self._construct_operator( + table="test_table", + metric_thresholds={"f0": 0.5, "f1": 0.6, "f2": 0.7, "f3": 0.8, }, + ratio_formula="relative_diff", + ignore_zero=True, + ) + + with six.assertRaisesRegex(self, AirflowException, "f0, f1"): + operator.execute() + + +class TestThresholdCheckOperator(unittest.TestCase): + def _construct_operator(self, sql, min_threshold, max_threshold): + dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1)) + + return ThresholdCheckOperator( + task_id="test_task", + sql=sql, + min_threshold=min_threshold, + max_threshold=max_threshold, + dag=dag, + ) + + @mock.patch.object(ThresholdCheckOperator, "get_db_hook") + def test_pass_min_value_max_value(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.return_value = [(10,)] + mock_get_db_hook.return_value = mock_hook + + operator = self._construct_operator( + "Select avg(val) from table1 limit 1", 1, 100 + ) + + operator.execute() + + @mock.patch.object(ThresholdCheckOperator, "get_db_hook") + def test_fail_min_value_max_value(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.return_value = [(10,)] + mock_get_db_hook.return_value = mock_hook + + operator = self._construct_operator( + "Select avg(val) from table1 limit 1", 20, 100 + ) + + with six.assertRaisesRegex(self, AirflowException, "10.*20.0.*100.0"): + operator.execute() + + @mock.patch.object(ThresholdCheckOperator, "get_db_hook") + def test_pass_min_sql_max_sql(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)] + mock_get_db_hook.return_value = mock_hook + + operator = self._construct_operator( + "Select 10", "Select 1", "Select 100") + + operator.execute() + + @mock.patch.object(ThresholdCheckOperator, "get_db_hook") + def test_fail_min_sql_max_sql(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)] + mock_get_db_hook.return_value = mock_hook + + operator = self._construct_operator( + "Select 10", "Select 20", "Select 100") + + with six.assertRaisesRegex(self, AirflowException, "10.*20.*100"): + operator.execute() + + @mock.patch.object(ThresholdCheckOperator, "get_db_hook") + def test_pass_min_value_max_sql(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)] + mock_get_db_hook.return_value = mock_hook + + operator = self._construct_operator("Select 75", 45, "Select 100") + + operator.execute() + + @mock.patch.object(ThresholdCheckOperator, "get_db_hook") + def test_fail_min_sql_max_value(self, mock_get_db_hook): + mock_hook = mock.Mock() + mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)] + mock_get_db_hook.return_value = mock_hook + + operator = self._construct_operator("Select 155", "Select 45", 100) + + with six.assertRaisesRegex(self, AirflowException, "155.*45.*100.0"): + operator.execute() + + class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): """ Test for SQL Branch Operator @@ -92,8 +358,8 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): session.query(TI).delete() def test_unsupported_conn_type(self): - """ Check if BranchSqlOperator throws an exception for unsupported connection type """ - op = BranchSqlOperator( + """ Check if BranchSQLOperator throws an exception for unsupported connection type """ + op = BranchSQLOperator( task_id="make_choice", conn_id="redis_default", sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES", @@ -103,11 +369,12 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): ) with self.assertRaises(AirflowException): - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, ignore_ti_state=True) def test_invalid_conn(self): - """ Check if BranchSqlOperator throws an exception for invalid connection """ - op = BranchSqlOperator( + """ Check if BranchSQLOperator throws an exception for invalid connection """ + op = BranchSQLOperator( task_id="make_choice", conn_id="invalid_connection", sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES", @@ -117,11 +384,12 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): ) with self.assertRaises(AirflowException): - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, ignore_ti_state=True) def test_invalid_follow_task_true(self): - """ Check if BranchSqlOperator throws an exception for invalid connection """ - op = BranchSqlOperator( + """ Check if BranchSQLOperator throws an exception for invalid connection """ + op = BranchSQLOperator( task_id="make_choice", conn_id="invalid_connection", sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES", @@ -131,11 +399,12 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): ) with self.assertRaises(AirflowException): - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, ignore_ti_state=True) def test_invalid_follow_task_false(self): - """ Check if BranchSqlOperator throws an exception for invalid connection """ - op = BranchSqlOperator( + """ Check if BranchSQLOperator throws an exception for invalid connection """ + op = BranchSQLOperator( task_id="make_choice", conn_id="invalid_connection", sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES", @@ -145,12 +414,13 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): ) with self.assertRaises(AirflowException): - op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + op.run(start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE, ignore_ti_state=True) @pytest.mark.backend("mysql") def test_sql_branch_operator_mysql(self): - """ Check if BranchSqlOperator works with backend """ - branch_op = BranchSqlOperator( + """ Check if BranchSQLOperator works with backend """ + branch_op = BranchSQLOperator( task_id="make_choice", conn_id="mysql_default", sql="SELECT 1", @@ -164,8 +434,8 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): @pytest.mark.backend("postgres") def test_sql_branch_operator_postgres(self): - """ Check if BranchSqlOperator works with backend """ - branch_op = BranchSqlOperator( + """ Check if BranchSQLOperator works with backend """ + branch_op = BranchSQLOperator( task_id="make_choice", conn_id="postgres_default", sql="SELECT 1", @@ -177,10 +447,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True ) - @mock.patch("airflow.operators.sql_branch_operator.BaseHook") + @mock.patch("airflow.operators.sql.BaseHook") def test_branch_single_value_with_dag_run(self, mock_hook): - """ Check BranchSqlOperator branch operation """ - branch_op = BranchSqlOperator( + """ Check BranchSQLOperator branch operation """ + branch_op = BranchSQLOperator( task_id="make_choice", conn_id="mysql_default", sql="SELECT 1", @@ -220,10 +490,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): else: raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id)) - @mock.patch("airflow.operators.sql_branch_operator.BaseHook") + @mock.patch("airflow.operators.sql.BaseHook") def test_branch_true_with_dag_run(self, mock_hook): - """ Check BranchSqlOperator branch operation """ - branch_op = BranchSqlOperator( + """ Check BranchSQLOperator branch operation """ + branch_op = BranchSQLOperator( task_id="make_choice", conn_id="mysql_default", sql="SELECT 1", @@ -264,10 +534,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): else: raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id)) - @mock.patch("airflow.operators.sql_branch_operator.BaseHook") + @mock.patch("airflow.operators.sql.BaseHook") def test_branch_false_with_dag_run(self, mock_hook): - """ Check BranchSqlOperator branch operation """ - branch_op = BranchSqlOperator( + """ Check BranchSQLOperator branch operation """ + branch_op = BranchSQLOperator( task_id="make_choice", conn_id="mysql_default", sql="SELECT 1", @@ -308,10 +578,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): else: raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id)) - @mock.patch("airflow.operators.sql_branch_operator.BaseHook") + @mock.patch("airflow.operators.sql.BaseHook") def test_branch_list_with_dag_run(self, mock_hook): - """ Checks if the BranchSqlOperator supports branching off to a list of tasks.""" - branch_op = BranchSqlOperator( + """ Checks if the BranchSQLOperator supports branching off to a list of tasks.""" + branch_op = BranchSQLOperator( task_id="make_choice", conn_id="mysql_default", sql="SELECT 1", @@ -354,10 +624,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): else: raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id)) - @mock.patch("airflow.operators.sql_branch_operator.BaseHook") + @mock.patch("airflow.operators.sql.BaseHook") def test_invalid_query_result_with_dag_run(self, mock_hook): - """ Check BranchSqlOperator branch operation """ - branch_op = BranchSqlOperator( + """ Check BranchSQLOperator branch operation """ + branch_op = BranchSQLOperator( task_id="make_choice", conn_id="mysql_default", sql="SELECT 1", @@ -387,10 +657,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): with self.assertRaises(AirflowException): branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - @mock.patch("airflow.operators.sql_branch_operator.BaseHook") + @mock.patch("airflow.operators.sql.BaseHook") def test_with_skip_in_branch_downstream_dependencies(self, mock_hook): """ Test SQL Branch with skipping all downstream dependencies """ - branch_op = BranchSqlOperator( + branch_op = BranchSQLOperator( task_id="make_choice", conn_id="mysql_default", sql="SELECT 1", @@ -431,10 +701,10 @@ class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): else: raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id)) - @mock.patch("airflow.operators.sql_branch_operator.BaseHook") + @mock.patch("airflow.operators.sql.BaseHook") def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook): """ Test skipping downstream dependency for false condition""" - branch_op = BranchSqlOperator( + branch_op = BranchSQLOperator( task_id="make_choice", conn_id="mysql_default", sql="SELECT 1", diff --git a/tests/secrets/test_local_filesystem.py b/tests/secrets/test_local_filesystem.py index 60cec06..dc06969 100644 --- a/tests/secrets/test_local_filesystem.py +++ b/tests/secrets/test_local_filesystem.py @@ -48,7 +48,7 @@ class FileParsers(unittest.TestCase): ) def test_env_file_invalid_format(self, content, expected_message): with mock_local_file(content): - with self.assertRaisesRegexp(AirflowFileParseException, re.escape(expected_message)): + with six.assertRaisesRegex(self, AirflowFileParseException, re.escape(expected_message)): local_filesystem.load_variables("a.env") @parameterized.expand( @@ -65,7 +65,7 @@ class FileParsers(unittest.TestCase): ) def test_json_file_invalid_format(self, content, expected_message): with mock_local_file(content): - with self.assertRaisesRegexp(AirflowFileParseException, re.escape(expected_message)): + with six.assertRaisesRegex(self, AirflowFileParseException, re.escape(expected_message)): local_filesystem.load_variables("a.json") @@ -87,7 +87,7 @@ class TestLoadVariables(unittest.TestCase): @parameterized.expand((("AA=A\nAA=B", "The \"a.env\" file contains multiple values for keys: ['AA']"),)) def test_env_file_invalid_logic(self, content, expected_message): with mock_local_file(content): - with self.assertRaisesRegexp(AirflowException, re.escape(expected_message)): + with six.assertRaisesRegex(self, AirflowException, re.escape(expected_message)): local_filesystem.load_variables("a.env") @parameterized.expand( @@ -105,7 +105,8 @@ class TestLoadVariables(unittest.TestCase): @mock.patch("airflow.secrets.local_filesystem.os.path.exists", return_value=False) def test_missing_file(self, mock_exists): - with self.assertRaisesRegexp( + with six.assertRaisesRegex( + self, AirflowException, re.escape("File a.json was not found. Check the configuration of your Secrets backend."), ): @@ -148,7 +149,7 @@ class TestLoadConnection(unittest.TestCase): ) def test_env_file_invalid_format(self, content, expected_message): with mock_local_file(content): - with self.assertRaisesRegexp(AirflowFileParseException, re.escape(expected_message)): + with six.assertRaisesRegex(self, AirflowFileParseException, re.escape(expected_message)): local_filesystem.load_connections("a.env") @parameterized.expand( @@ -189,12 +190,13 @@ class TestLoadConnection(unittest.TestCase): ) def test_env_file_invalid_input(self, file_content, expected_connection_uris): with mock_local_file(json.dumps(file_content)): - with self.assertRaisesRegexp(AirflowException, re.escape(expected_connection_uris)): + with six.assertRaisesRegex(self, AirflowException, re.escape(expected_connection_uris)): local_filesystem.load_connections("a.json") @mock.patch("airflow.secrets.local_filesystem.os.path.exists", return_value=False) def test_missing_file(self, mock_exists): - with self.assertRaisesRegexp( + with six.assertRaisesRegex( + self, AirflowException, re.escape("File a.json was not found. Check the configuration of your Secrets backend."), ): diff --git a/tests/sensors/test_http_sensor.py b/tests/sensors/test_http_sensor.py index 5b2d19e..db0f02a 100644 --- a/tests/sensors/test_http_sensor.py +++ b/tests/sensors/test_http_sensor.py @@ -19,6 +19,7 @@ import unittest import requests +import six from mock import patch from airflow import DAG @@ -61,7 +62,7 @@ class HttpSensorTests(unittest.TestCase): response_check=resp_check, timeout=5, poke_interval=1) - with self.assertRaisesRegexp(AirflowException, 'AirflowException raised here!'): + with six.assertRaisesRegex(self, AirflowException, 'AirflowException raised here!'): task.execute(None) @patch("airflow.hooks.http_hook.requests.Session.send") diff --git a/tests/utils/test_compression.py b/tests/utils/test_compression.py index 5a36709..022d981 100644 --- a/tests/utils/test_compression.py +++ b/tests/utils/test_compression.py @@ -26,6 +26,8 @@ import shutil import tempfile import unittest +import six + from airflow.utils import compression @@ -81,13 +83,13 @@ class Compression(unittest.TestCase): def test_uncompress_file(self): # Testing txt file type - self.assertRaisesRegexp(NotImplementedError, - "^Received .txt format. Only gz and bz2.*", - compression.uncompress_file, - **{'input_file_name': None, - 'file_extension': '.txt', - 'dest_dir': None - }) + six.assertRaisesRegex(self, NotImplementedError, + "^Received .txt format. Only gz and bz2.*", + compression.uncompress_file, + **{'input_file_name': None, + 'file_extension': '.txt', + 'dest_dir': None + }) # Testing gz file type fn_txt = self._get_fn('.txt') fn_gz = self._get_fn('.gz') diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py index d23cdcc..05df91b 100644 --- a/tests/utils/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -19,6 +19,8 @@ import unittest +import six + from airflow.utils.decorators import apply_defaults from airflow.exceptions import AirflowException @@ -43,7 +45,7 @@ class ApplyDefaultTest(unittest.TestCase): dc = DummyClass(test_param=True) self.assertTrue(dc.test_param) - with self.assertRaisesRegexp(AirflowException, 'Argument.*test_param.*required'): + with six.assertRaisesRegex(self, AirflowException, 'Argument.*test_param.*required'): DummySubClass(test_sub_param=True) def test_default_args(self): @@ -61,8 +63,8 @@ class ApplyDefaultTest(unittest.TestCase): self.assertTrue(dc.test_param) self.assertTrue(dsc.test_sub_param) - with self.assertRaisesRegexp(AirflowException, - 'Argument.*test_sub_param.*required'): + with six.assertRaisesRegex(self, AirflowException, + 'Argument.*test_sub_param.*required'): DummySubClass(default_args=default_args) def test_incorrect_default_args(self): @@ -71,5 +73,5 @@ class ApplyDefaultTest(unittest.TestCase): self.assertTrue(dc.test_param) default_args = {'random_params': True} - with self.assertRaisesRegexp(AirflowException, 'Argument.*test_param.*required'): + with six.assertRaisesRegex(self, AirflowException, 'Argument.*test_param.*required'): DummyClass(default_args=default_args) diff --git a/tests/utils/test_json.py b/tests/utils/test_json.py index ce0eece..bc3c7f9 100644 --- a/tests/utils/test_json.py +++ b/tests/utils/test_json.py @@ -22,6 +22,7 @@ import json import unittest import numpy as np +import six from airflow.utils import json as utils_json @@ -76,11 +77,11 @@ class TestAirflowJsonEncoder(unittest.TestCase): ) def test_encode_raises(self): - self.assertRaisesRegexp(TypeError, - "^.*is not JSON serializable$", - json.dumps, - Exception, - cls=utils_json.AirflowJsonEncoder) + six.assertRaisesRegex(self, TypeError, + "^.*is not JSON serializable$", + json.dumps, + Exception, + cls=utils_json.AirflowJsonEncoder) if __name__ == '__main__': diff --git a/tests/utils/test_module_loading.py b/tests/utils/test_module_loading.py index ba1ebca..cde32c5 100644 --- a/tests/utils/test_module_loading.py +++ b/tests/utils/test_module_loading.py @@ -19,6 +19,8 @@ import unittest +import six + from airflow.utils.module_loading import import_string @@ -31,5 +33,5 @@ class ModuleImportTestCase(unittest.TestCase): with self.assertRaises(ImportError): import_string('no_dots_in_path') msg = 'Module "airflow.utils" does not define a "nonexistent" attribute' - with self.assertRaisesRegexp(ImportError, msg): + with six.assertRaisesRegex(self, ImportError, msg): import_string('airflow.utils.nonexistent') diff --git a/tests/www/test_validators.py b/tests/www/test_validators.py index e624263..6b4fcbd 100644 --- a/tests/www/test_validators.py +++ b/tests/www/test_validators.py @@ -20,6 +20,8 @@ import mock import unittest +import six + from airflow.www import validators @@ -46,7 +48,8 @@ class TestGreaterEqualThan(unittest.TestCase): return validator(self.form_mock, self.form_field_mock) def test_field_not_found(self): - self.assertRaisesRegexp( + six.assertRaisesRegex( + self, validators.ValidationError, "^Invalid field name 'some'.$", self._validate, @@ -75,7 +78,8 @@ class TestGreaterEqualThan(unittest.TestCase): def test_validation_raises(self): self.form_field_mock.data = '2017-05-04' - self.assertRaisesRegexp( + six.assertRaisesRegex( + self, validators.ValidationError, "^Field must be greater than or equal to other field.$", self._validate, @@ -84,7 +88,8 @@ class TestGreaterEqualThan(unittest.TestCase): def test_validation_raises_custom_message(self): self.form_field_mock.data = '2017-05-04' - self.assertRaisesRegexp( + six.assertRaisesRegex( + self, validators.ValidationError, "^This field must be greater than or equal to MyField.$", self._validate, diff --git a/tests/www_rbac/test_validators.py b/tests/www_rbac/test_validators.py index 4a543ff..95c7562 100644 --- a/tests/www_rbac/test_validators.py +++ b/tests/www_rbac/test_validators.py @@ -48,7 +48,8 @@ class TestGreaterEqualThan(unittest.TestCase): return validator(self.form_mock, self.form_field_mock) def test_field_not_found(self): - self.assertRaisesRegexp( + six.assertRaisesRegex( + self, validators.ValidationError, "^Invalid field name 'some'.$", self._validate, @@ -77,7 +78,8 @@ class TestGreaterEqualThan(unittest.TestCase): def test_validation_raises(self): self.form_field_mock.data = '2017-05-04' - self.assertRaisesRegexp( + six.assertRaisesRegex( + self, validators.ValidationError, "^Field must be greater than or equal to other field.$", self._validate, @@ -86,7 +88,8 @@ class TestGreaterEqualThan(unittest.TestCase): def test_validation_raises_custom_message(self): self.form_field_mock.data = '2017-05-04' - self.assertRaisesRegexp( + six.assertRaisesRegex( + self, validators.ValidationError, "^This field must be greater than or equal to MyField.$", self._validate,
