This is an automated email from the ASF dual-hosted git repository. dimberman pushed a commit to branch v1-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit eb1557173aee6340bd3595a341aebbcc70bc689f Author: Daniel Imberman <[email protected]> AuthorDate: Fri Jun 19 15:15:50 2020 -0700 flake8 pass Merging multiple sql operators (#9124) --- airflow/operators/check_operator.py | 8 +-- airflow/operators/sql.py | 93 +++++++++++++++++--------------- airflow/operators/sql_branch_operator.py | 2 +- tests/operators/test_sql.py | 5 +- tests/test_core_to_contrib.py | 15 ++---- 5 files changed, 61 insertions(+), 62 deletions(-) diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py index 4810eeb..12ac472 100644 --- a/airflow/operators/check_operator.py +++ b/airflow/operators/check_operator.py @@ -38,7 +38,7 @@ class CheckOperator(SQLCheckOperator): Please use `airflow.operators.sql.SQLCheckOperator`.""", DeprecationWarning, stacklevel=2 ) - super().__init__(*args, **kwargs) + super(CheckOperator, self).__init__(*args, **kwargs) class IntervalCheckOperator(SQLIntervalCheckOperator): @@ -53,7 +53,7 @@ class IntervalCheckOperator(SQLIntervalCheckOperator): Please use `airflow.operators.sql.SQLIntervalCheckOperator`.""", DeprecationWarning, stacklevel=2 ) - super().__init__(*args, **kwargs) + super(IntervalCheckOperator, self).__init__(*args, **kwargs) class ThresholdCheckOperator(SQLThresholdCheckOperator): @@ -68,7 +68,7 @@ class ThresholdCheckOperator(SQLThresholdCheckOperator): Please use `airflow.operators.sql.SQLThresholdCheckOperator`.""", DeprecationWarning, stacklevel=2 ) - super().__init__(*args, **kwargs) + super(ThresholdCheckOperator, self).__init__(*args, **kwargs) class ValueCheckOperator(SQLValueCheckOperator): @@ -83,4 +83,4 @@ class ValueCheckOperator(SQLValueCheckOperator): Please use `airflow.operators.sql.SQLValueCheckOperator`.""", DeprecationWarning, stacklevel=2 ) - super().__init__(*args, **kwargs) + super(ValueCheckOperator, self).__init__(*args, **kwargs) diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py index fd997d9..83cb201 100644 --- a/airflow/operators/sql.py +++ b/airflow/operators/sql.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. from distutils.util import strtobool -from typing import Any, Dict, Iterable, List, Mapping, Optional, SupportsAbs, Union +from typing import Iterable from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook @@ -82,9 +82,9 @@ class SQLCheckOperator(BaseOperator): @apply_defaults def __init__( - self, sql: str, conn_id: Optional[str] = None, *args, **kwargs - ) -> None: - super().__init__(*args, **kwargs) + self, sql, conn_id=None, *args, **kwargs + ): + super(SQLCheckOperator, self).__init__(*args, **kwargs) self.conn_id = conn_id self.sql = sql @@ -155,14 +155,14 @@ class SQLValueCheckOperator(BaseOperator): @apply_defaults def __init__( self, - sql: str, - pass_value: Any, - tolerance: Any = None, - conn_id: Optional[str] = None, + sql, + pass_value, + tolerance=None, + conn_id=None, *args, - **kwargs, + **kwargs ): - super().__init__(*args, **kwargs) + super(SQLValueCheckOperator, self).__init__(*args, **kwargs) self.sql = sql self.conn_id = conn_id self.pass_value = str(pass_value) @@ -278,17 +278,17 @@ class SQLIntervalCheckOperator(BaseOperator): @apply_defaults def __init__( self, - table: str, - metrics_thresholds: Dict[str, int], - date_filter_column: Optional[str] = "ds", - days_back: SupportsAbs[int] = -7, - ratio_formula: Optional[str] = "max_over_min", - ignore_zero: Optional[bool] = True, - conn_id: Optional[str] = None, + table, + metrics_thresholds, + date_filter_column="ds", + days_back=-7, + ratio_formula="max_over_min", + ignore_zero=True, + conn_id=None, *args, - **kwargs, + **kwargs ): - super().__init__(*args, **kwargs) + super(SQLIntervalCheckOperator, self).__init__(*args, **kwargs) if ratio_formula not in self.ratio_formulas: msg_template = ( "Invalid diff_method: {diff_method}. " @@ -423,14 +423,14 @@ class SQLThresholdCheckOperator(BaseOperator): @apply_defaults def __init__( self, - sql: str, - min_threshold: Any, - max_threshold: Any, - conn_id: Optional[str] = None, + sql, + min_threshold, + max_threshold, + conn_id=None, *args, - **kwargs, + **kwargs ): - super().__init__(*args, **kwargs) + super(SQLThresholdCheckOperator, self).__init__(*args, **kwargs) self.sql = sql self.conn_id = conn_id self.min_threshold = _convert_to_float_if_possible(min_threshold) @@ -461,13 +461,20 @@ class SQLThresholdCheckOperator(BaseOperator): self.push(meta_data) if not meta_data["within_threshold"]: error_msg = ( - f'Threshold Check: "{meta_data.get("task_id")}" failed.\n' - f'DAG: {self.dag_id}\nTask_id: {meta_data.get("task_id")}\n' - f'Check description: {meta_data.get("description")}\n' - f"SQL: {self.sql}\n" - f'Result: {round(meta_data.get("result"), 2)} is not within thresholds ' - f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}' - ) + 'Threshold Check: "{task_id}" failed.\n' + 'DAG: {dag_id}\nTask_id: {task_id}\n' + 'Check description: {description}\n' + "SQL: {sql}\n" + 'Result: {round} is not within thresholds ' + '{min} and {max}' + .format(task_id=meta_data.get("task_id"), + dag_id=self.dag_id, + description=meta_data.get("description"), + sql=self.sql, + round=round(meta_data.get("result"), 2), + min=meta_data.get("min_threshold"), + max=meta_data.get("max_threshold"), + )) raise AirflowException(error_msg) self.log.info("Test %s Successful.", self.task_id) @@ -478,7 +485,7 @@ class SQLThresholdCheckOperator(BaseOperator): Default functionality will log metadata. """ - info = "\n".join([f"""{key}: {item}""" for key, item in meta_data.items()]) + info = "\n".join(["{key}: {item}".format(key=key, item=item) for key, item in meta_data.items()]) self.log.info("Log from %s:\n%s", self.dag_id, info) def get_db_hook(self): @@ -516,16 +523,16 @@ class BranchSQLOperator(BaseOperator, SkipMixin): @apply_defaults def __init__( self, - sql: str, - follow_task_ids_if_true: List[str], - follow_task_ids_if_false: List[str], - conn_id: str = "default_conn_id", - database: Optional[str] = None, - parameters: Optional[Union[Mapping, Iterable]] = None, + sql, + follow_task_ids_if_true, + follow_task_ids_if_false, + conn_id="default_conn_id", + database=None, + parameters=None, *args, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) + **kwargs + ): + super(BranchSQLOperator, self).__init__(*args, **kwargs) self.conn_id = conn_id self.sql = sql self.parameters = parameters @@ -551,7 +558,7 @@ class BranchSQLOperator(BaseOperator, SkipMixin): return self._hook - def execute(self, context: Dict): + def execute(self, context): # get supported hook self._hook = self._get_hook() diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py index cd319aa..b911e34 100644 --- a/airflow/operators/sql_branch_operator.py +++ b/airflow/operators/sql_branch_operator.py @@ -32,4 +32,4 @@ class BranchSqlOperator(BranchSQLOperator): Please use `airflow.operators.sql.BranchSQLOperator`.""", DeprecationWarning, stacklevel=2 ) - super().__init__(*args, **kwargs) + super(BranchSqlOperator, self).__init__(*args, **kwargs) diff --git a/tests/operators/test_sql.py b/tests/operators/test_sql.py index a538f15..e5c1f98 100644 --- a/tests/operators/test_sql.py +++ b/tests/operators/test_sql.py @@ -200,8 +200,7 @@ class TestIntervalCheckOperator(unittest.TestCase): [2, 2, 2, 2], # reference [1, 1, 1, 1], # current ] - - yield from rows + return rows mock_hook.get_first.side_effect = returned_row() mock_get_db_hook.return_value = mock_hook @@ -226,7 +225,7 @@ class TestIntervalCheckOperator(unittest.TestCase): [1, 1, 1, 1], # current ] - yield from rows + return rows mock_hook.get_first.side_effect = returned_row() mock_get_db_hook.return_value = mock_hook diff --git a/tests/test_core_to_contrib.py b/tests/test_core_to_contrib.py index 0a3e7fb..127905a 100644 --- a/tests/test_core_to_contrib.py +++ b/tests/test_core_to_contrib.py @@ -19,12 +19,10 @@ import importlib import sys from inspect import isabstract -from typing import Any from unittest import TestCase, mock from parameterized import parameterized -HOOKS = [] OPERATORS = [ ( @@ -49,24 +47,19 @@ OPERATORS = [ ), ] -SECRETS = [] -SENSORS = [] - -TRANSFERS = [] - -ALL = HOOKS + OPERATORS + SECRETS + SENSORS + TRANSFERS +ALL = OPERATORS RENAMED_HOOKS = [ (old_class, new_class) - for old_class, new_class in HOOKS + OPERATORS + SECRETS + SENSORS + for old_class, new_class in OPERATORS if old_class.rpartition(".")[2] != new_class.rpartition(".")[2] ] class TestMovingCoreToContrib(TestCase): @staticmethod - def assert_warning(msg: str, warning: Any): + def assert_warning(msg, warning): error = "Text '{}' not in warnings".format(msg) assert any(msg in str(w) for w in warning.warnings), error @@ -100,7 +93,7 @@ class TestMovingCoreToContrib(TestCase): class_ = getattr(module, class_name) if isabstract(class_) and not parent: - class_name = f"Mock({class_.__name__})" + class_name = "Mock({class_name})".format(class_name=class_.__name__) attributes = { a: mock.MagicMock() for a in class_.__abstractmethods__
