This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch task-sdk-first-code in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 389942917dc437c99a7697f213a708726ee1386c Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Tue Oct 29 20:00:15 2024 +0000 fix more tests [skip ci] --- .../airflow/providers/apache/drill/operators/drill.py | 4 ++-- .../src/airflow/providers/common/sql/operators/sql.py | 18 +++++++++--------- .../src/airflow/providers/common/sql/operators/sql.pyi | 18 +++++++++--------- .../providers/databricks/operators/databricks_sql.py | 4 ++-- .../src/airflow/providers/exasol/operators/exasol.py | 4 ++-- providers/src/airflow/providers/jdbc/operators/jdbc.py | 4 ++-- .../providers/microsoft/mssql/operators/mssql.py | 4 ++-- .../src/airflow/providers/mysql/operators/mysql.py | 4 ++-- .../src/airflow/providers/oracle/operators/oracle.py | 4 ++-- .../airflow/providers/postgres/operators/postgres.py | 7 +++++-- .../airflow/providers/snowflake/operators/snowflake.py | 4 ++-- .../src/airflow/providers/sqlite/operators/sqlite.py | 4 ++-- .../airflow/providers/teradata/operators/teradata.py | 4 ++-- .../src/airflow/providers/trino/operators/trino.py | 4 ++-- .../src/airflow/providers/vertica/operators/vertica.py | 4 ++-- .../tests/google/cloud/operators/test_bigquery.py | 4 ++-- providers/tests/google/cloud/operators/test_compute.py | 6 ++++-- .../tests/google/cloud/operators/test_dataflow.py | 2 +- 18 files changed, 54 insertions(+), 49 deletions(-) diff --git a/providers/src/airflow/providers/apache/drill/operators/drill.py b/providers/src/airflow/providers/apache/drill/operators/drill.py index 5aa0f061baf..edf9d0f7359 100644 --- a/providers/src/airflow/providers/apache/drill/operators/drill.py +++ b/providers/src/airflow/providers/apache/drill/operators/drill.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Sequence +from typing import ClassVar, Sequence from deprecated import deprecated @@ -46,7 +46,7 @@ class DrillOperator(SQLExecuteQueryOperator): """ template_fields: Sequence[str] = ("sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} template_ext: Sequence[str] = (".sql",) ui_color = "#ededed" diff --git a/providers/src/airflow/providers/common/sql/operators/sql.py b/providers/src/airflow/providers/common/sql/operators/sql.py index dae389be028..44982b8e1ef 100644 --- a/providers/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/src/airflow/providers/common/sql/operators/sql.py @@ -20,7 +20,7 @@ from __future__ import annotations import ast import re from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, NoReturn, Sequence, SupportsAbs +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Mapping, NoReturn, Sequence, SupportsAbs from airflow.exceptions import AirflowException, AirflowFailException from airflow.hooks.base import BaseHook @@ -224,7 +224,7 @@ class SQLExecuteQueryOperator(BaseSQLOperator): template_fields: Sequence[str] = ("sql", "parameters", *BaseSQLOperator.template_fields) template_ext: Sequence[str] = (".sql", ".json") - template_fields_renderers = {"sql": "sql", "parameters": "json"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql", "parameters": "json"} ui_color = "#cdaaed" def __init__( @@ -428,7 +428,7 @@ class SQLColumnCheckOperator(BaseSQLOperator): """ template_fields: Sequence[str] = ("table", "partition_clause", "sql", *BaseSQLOperator.template_fields) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} sql_check_template = """ SELECT '{column}' AS col_name, '{check}' AS check_type, {column}_{check} AS check_result @@ -657,7 +657,7 @@ class SQLTableCheckOperator(BaseSQLOperator): template_fields: Sequence[str] = ("table", "partition_clause", "sql", *BaseSQLOperator.template_fields) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} sql_check_template = """ SELECT '{check_name}' AS check_name, MIN({check_name}) AS check_result @@ -776,7 +776,7 @@ class SQLCheckOperator(BaseSQLOperator): ".hql", ".sql", ) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#fff7e6" def __init__( @@ -822,7 +822,7 @@ class SQLValueCheckOperator(BaseSQLOperator): ".hql", ".sql", ) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#fff7e6" def __init__( @@ -919,7 +919,7 @@ class SQLIntervalCheckOperator(BaseSQLOperator): ".hql", ".sql", ) - template_fields_renderers = {"sql1": "sql", "sql2": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql1": "sql", "sql2": "sql"} ui_color = "#fff7e6" ratio_formulas = { @@ -1052,7 +1052,7 @@ class SQLThresholdCheckOperator(BaseSQLOperator): ".hql", ".sql", ) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} def __init__( self, @@ -1147,7 +1147,7 @@ class BranchSQLOperator(BaseSQLOperator, SkipMixin): template_fields: Sequence[str] = ("sql", *BaseSQLOperator.template_fields) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#a22034" ui_fgcolor = "#F7F7F7" diff --git a/providers/src/airflow/providers/common/sql/operators/sql.pyi b/providers/src/airflow/providers/common/sql/operators/sql.pyi index 1b97cec5023..6921e3411ea 100644 --- a/providers/src/airflow/providers/common/sql/operators/sql.pyi +++ b/providers/src/airflow/providers/common/sql/operators/sql.pyi @@ -36,7 +36,7 @@ from airflow.models import BaseOperator as BaseOperator, SkipMixin as SkipMixin from airflow.providers.common.sql.hooks.sql import DbApiHook as DbApiHook from airflow.providers.openlineage.extractors import OperatorLineage as OperatorLineage from airflow.utils.context import Context as Context -from typing import Any, Callable, Iterable, Mapping, Sequence, SupportsAbs +from typing import Any, Callable, ClassVar, Iterable, Mapping, Sequence, SupportsAbs def parse_boolean(val: str) -> str | bool: ... @@ -62,7 +62,7 @@ class SQLExecuteQueryOperator(BaseSQLOperator): def _raise_exception(self, exception_string: str) -> Incomplete: ... template_fields: Sequence[str] template_ext: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] ui_color: str sql: Incomplete autocommit: Incomplete @@ -92,7 +92,7 @@ class SQLExecuteQueryOperator(BaseSQLOperator): class SQLColumnCheckOperator(BaseSQLOperator): template_fields: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] sql_check_template: str column_checks: Incomplete table: Incomplete @@ -115,7 +115,7 @@ class SQLColumnCheckOperator(BaseSQLOperator): class SQLTableCheckOperator(BaseSQLOperator): template_fields: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] sql_check_template: str table: Incomplete checks: Incomplete @@ -136,7 +136,7 @@ class SQLTableCheckOperator(BaseSQLOperator): class SQLCheckOperator(BaseSQLOperator): template_fields: Sequence[str] template_ext: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] ui_color: str sql: Incomplete parameters: Incomplete @@ -155,7 +155,7 @@ class SQLValueCheckOperator(BaseSQLOperator): __mapper_args__: Incomplete template_fields: Sequence[str] template_ext: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] ui_color: str sql: Incomplete pass_value: Incomplete @@ -178,7 +178,7 @@ class SQLIntervalCheckOperator(BaseSQLOperator): __mapper_args__: Incomplete template_fields: Sequence[str] template_ext: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] ui_color: str ratio_formulas: Incomplete ratio_formula: Incomplete @@ -208,7 +208,7 @@ class SQLIntervalCheckOperator(BaseSQLOperator): class SQLThresholdCheckOperator(BaseSQLOperator): template_fields: Sequence[str] template_ext: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] sql: Incomplete min_threshold: Incomplete max_threshold: Incomplete @@ -228,7 +228,7 @@ class SQLThresholdCheckOperator(BaseSQLOperator): class BranchSQLOperator(BaseSQLOperator, SkipMixin): template_fields: Sequence[str] template_ext: Sequence[str] - template_fields_renderers: Incomplete + template_fields_renderers: ClassVar[dict] ui_color: str ui_fgcolor: str sql: Incomplete diff --git a/providers/src/airflow/providers/databricks/operators/databricks_sql.py b/providers/src/airflow/providers/databricks/operators/databricks_sql.py index 1975b4de214..7e59fc2a9d5 100644 --- a/providers/src/airflow/providers/databricks/operators/databricks_sql.py +++ b/providers/src/airflow/providers/databricks/operators/databricks_sql.py @@ -21,7 +21,7 @@ from __future__ import annotations import csv import json -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Sequence from databricks.sql.utils import ParamEscaper @@ -72,7 +72,7 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator): ) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} conn_id_field = "databricks_conn_id" def __init__( diff --git a/providers/src/airflow/providers/exasol/operators/exasol.py b/providers/src/airflow/providers/exasol/operators/exasol.py index 407fdf65916..51c0131fa5b 100644 --- a/providers/src/airflow/providers/exasol/operators/exasol.py +++ b/providers/src/airflow/providers/exasol/operators/exasol.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Sequence +from typing import ClassVar, Sequence from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.providers.exasol.hooks.exasol import exasol_fetch_all_handler @@ -40,7 +40,7 @@ class ExasolOperator(SQLExecuteQueryOperator): template_fields: Sequence[str] = ("sql", "exasol_conn_id") template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#ededed" conn_id_field = "exasol_conn_id" diff --git a/providers/src/airflow/providers/jdbc/operators/jdbc.py b/providers/src/airflow/providers/jdbc/operators/jdbc.py index b889eb64518..3357b569c8f 100644 --- a/providers/src/airflow/providers/jdbc/operators/jdbc.py +++ b/providers/src/airflow/providers/jdbc/operators/jdbc.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Sequence +from typing import ClassVar, Sequence from deprecated import deprecated @@ -54,7 +54,7 @@ class JdbcOperator(SQLExecuteQueryOperator): template_fields: Sequence[str] = ("sql",) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#ededed" def __init__(self, *, jdbc_conn_id: str = "jdbc_default", **kwargs) -> None: diff --git a/providers/src/airflow/providers/microsoft/mssql/operators/mssql.py b/providers/src/airflow/providers/microsoft/mssql/operators/mssql.py index 5c24831ef1d..e21c78c3bf6 100644 --- a/providers/src/airflow/providers/microsoft/mssql/operators/mssql.py +++ b/providers/src/airflow/providers/microsoft/mssql/operators/mssql.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Sequence +from typing import ClassVar, Sequence from deprecated import deprecated @@ -59,7 +59,7 @@ class MsSqlOperator(SQLExecuteQueryOperator): template_fields: Sequence[str] = ("sql",) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "tsql"} + template_fields_renderers: ClassVar[dict] = {"sql": "tsql"} ui_color = "#ededed" def __init__( diff --git a/providers/src/airflow/providers/mysql/operators/mysql.py b/providers/src/airflow/providers/mysql/operators/mysql.py index 2c2436b4d9d..7a47dd0a68c 100644 --- a/providers/src/airflow/providers/mysql/operators/mysql.py +++ b/providers/src/airflow/providers/mysql/operators/mysql.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Sequence +from typing import ClassVar, Sequence from deprecated import deprecated @@ -58,7 +58,7 @@ class MySqlOperator(SQLExecuteQueryOperator): """ template_fields: Sequence[str] = ("sql", "parameters") - template_fields_renderers = { + template_fields_renderers: ClassVar[dict] = { "sql": "mysql", "parameters": "json", } diff --git a/providers/src/airflow/providers/oracle/operators/oracle.py b/providers/src/airflow/providers/oracle/operators/oracle.py index 0debfed2c6b..5770271d636 100644 --- a/providers/src/airflow/providers/oracle/operators/oracle.py +++ b/providers/src/airflow/providers/oracle/operators/oracle.py @@ -18,7 +18,7 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, ClassVar, Sequence import oracledb from deprecated import deprecated @@ -60,7 +60,7 @@ class OracleOperator(SQLExecuteQueryOperator): "sql", ) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#ededed" def __init__(self, *, oracle_conn_id: str = "oracle_default", **kwargs) -> None: diff --git a/providers/src/airflow/providers/postgres/operators/postgres.py b/providers/src/airflow/providers/postgres/operators/postgres.py index 424a86b6669..f936d462697 100644 --- a/providers/src/airflow/providers/postgres/operators/postgres.py +++ b/providers/src/airflow/providers/postgres/operators/postgres.py @@ -18,7 +18,7 @@ from __future__ import annotations import warnings -from typing import Mapping +from typing import ClassVar, Mapping from deprecated import deprecated @@ -55,7 +55,10 @@ class PostgresOperator(SQLExecuteQueryOperator): Deprecated - use `hook_params={'options': '-c <connection_options>'}` instead. """ - template_fields_renderers = {**SQLExecuteQueryOperator.template_fields_renderers, "sql": "postgresql"} + template_fields_renderers: ClassVar[dict] = { + **SQLExecuteQueryOperator.template_fields_renderers, + "sql": "postgresql", + } ui_color = "#ededed" def __init__( diff --git a/providers/src/airflow/providers/snowflake/operators/snowflake.py b/providers/src/airflow/providers/snowflake/operators/snowflake.py index 89af8fbb6fd..c7c19d740b1 100644 --- a/providers/src/airflow/providers/snowflake/operators/snowflake.py +++ b/providers/src/airflow/providers/snowflake/operators/snowflake.py @@ -19,7 +19,7 @@ from __future__ import annotations import time from datetime import timedelta -from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Sequence, SupportsAbs, cast +from typing import TYPE_CHECKING, Any, ClassVar, Iterable, List, Mapping, Sequence, SupportsAbs, cast from deprecated import deprecated @@ -88,7 +88,7 @@ class SnowflakeOperator(SQLExecuteQueryOperator): template_fields: Sequence[str] = ("sql",) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#ededed" def __init__( diff --git a/providers/src/airflow/providers/sqlite/operators/sqlite.py b/providers/src/airflow/providers/sqlite/operators/sqlite.py index 38c085178f2..1e38696263f 100644 --- a/providers/src/airflow/providers/sqlite/operators/sqlite.py +++ b/providers/src/airflow/providers/sqlite/operators/sqlite.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Sequence +from typing import ClassVar, Sequence from deprecated import deprecated @@ -51,7 +51,7 @@ class SqliteOperator(SQLExecuteQueryOperator): template_fields: Sequence[str] = ("sql",) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#cdaaed" def __init__(self, *, sqlite_conn_id: str = "sqlite_default", **kwargs) -> None: diff --git a/providers/src/airflow/providers/teradata/operators/teradata.py b/providers/src/airflow/providers/teradata/operators/teradata.py index c15fc290385..edb1331c612 100644 --- a/providers/src/airflow/providers/teradata/operators/teradata.py +++ b/providers/src/airflow/providers/teradata/operators/teradata.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, ClassVar, Sequence from airflow.models import BaseOperator from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator @@ -49,7 +49,7 @@ class TeradataOperator(SQLExecuteQueryOperator): "parameters", ) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#e07c24" def __init__( diff --git a/providers/src/airflow/providers/trino/operators/trino.py b/providers/src/airflow/providers/trino/operators/trino.py index 76856728a48..9ff9768d745 100644 --- a/providers/src/airflow/providers/trino/operators/trino.py +++ b/providers/src/airflow/providers/trino/operators/trino.py @@ -19,7 +19,7 @@ from __future__ import annotations -from typing import Any, Sequence +from typing import Any, ClassVar, Sequence from deprecated import deprecated from trino.exceptions import TrinoQueryError @@ -56,7 +56,7 @@ class TrinoOperator(SQLExecuteQueryOperator): """ template_fields: Sequence[str] = ("sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} template_ext: Sequence[str] = (".sql",) ui_color = "#ededed" diff --git a/providers/src/airflow/providers/vertica/operators/vertica.py b/providers/src/airflow/providers/vertica/operators/vertica.py index 6373dfdf8b4..03cf14f000e 100644 --- a/providers/src/airflow/providers/vertica/operators/vertica.py +++ b/providers/src/airflow/providers/vertica/operators/vertica.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import Any, Sequence +from typing import Any, ClassVar, Sequence from deprecated import deprecated @@ -45,7 +45,7 @@ class VerticaOperator(SQLExecuteQueryOperator): template_fields: Sequence[str] = ("sql",) template_ext: Sequence[str] = (".sql",) - template_fields_renderers = {"sql": "sql"} + template_fields_renderers: ClassVar[dict] = {"sql": "sql"} ui_color = "#b4e0ff" def __init__(self, *, vertica_conn_id: str = "vertica_default", **kwargs: Any) -> None: diff --git a/providers/tests/google/cloud/operators/test_bigquery.py b/providers/tests/google/cloud/operators/test_bigquery.py index 8d6f4de47ff..4269d02377b 100644 --- a/providers/tests/google/cloud/operators/test_bigquery.py +++ b/providers/tests/google/cloud/operators/test_bigquery.py @@ -2578,7 +2578,7 @@ class TestBigQueryValueCheckOperator: """ Assert the exception if require param not pass to BigQueryValueCheckOperator with deferrable=True """ - with pytest.raises(TypeError) as missing_param: + with pytest.raises((TypeError, AirflowException)) as missing_param: BigQueryValueCheckOperator(deferrable=True, **kwargs) assert missing_param.value.args[0] == expected @@ -2590,7 +2590,7 @@ class TestBigQueryValueCheckOperator: "missing keyword arguments 'sql', 'pass_value'", "missing keyword arguments 'pass_value', 'sql'", ) - with pytest.raises(TypeError) as missing_param: + with pytest.raises((TypeError, AirflowException)) as missing_param: BigQueryValueCheckOperator(deferrable=True, kwargs={}) assert missing_param.value.args[0] in (expected, expected1) diff --git a/providers/tests/google/cloud/operators/test_compute.py b/providers/tests/google/cloud/operators/test_compute.py index fcfb4e9d48d..7913143618c 100644 --- a/providers/tests/google/cloud/operators/test_compute.py +++ b/providers/tests/google/cloud/operators/test_compute.py @@ -349,7 +349,9 @@ class TestGceInstanceInsertFromTemplate: ) def test_insert_instance_from_template_should_throw_ex_when_missing_source_instance_template(self): - with pytest.raises(TypeError, match=r"missing keyword argument 'source_instance_template'"): + with pytest.raises( + (TypeError, AirflowException), match=r"missing keyword argument 'source_instance_template'" + ): ComputeEngineInsertInstanceFromTemplateOperator( project_id=GCP_PROJECT_ID, body=GCP_INSTANCE_BODY_FROM_TEMPLATE, @@ -360,7 +362,7 @@ class TestGceInstanceInsertFromTemplate: ) def test_insert_instance_from_template_should_throw_ex_when_missing_body(self): - with pytest.raises(TypeError, match=r"missing keyword argument 'body'"): + with pytest.raises((TypeError, AirflowException), match=r"missing keyword argument 'body'"): ComputeEngineInsertInstanceFromTemplateOperator( project_id=GCP_PROJECT_ID, source_instance_template=SOURCE_INSTANCE_TEMPLATE, diff --git a/providers/tests/google/cloud/operators/test_dataflow.py b/providers/tests/google/cloud/operators/test_dataflow.py index 96e0621add8..bc23d84d228 100644 --- a/providers/tests/google/cloud/operators/test_dataflow.py +++ b/providers/tests/google/cloud/operators/test_dataflow.py @@ -1077,7 +1077,7 @@ class TestDataflowRunPipelineOperator: "location": TEST_LOCATION, "gcp_conn_id": GCP_CONN_ID, } - with pytest.raises(TypeError, match="missing keyword argument"): + with pytest.raises((TypeError, AirflowException), match="missing keyword argument"): DataflowRunPipelineOperator(**init_kwargs).execute(mock.MagicMock()).return_value = { "error": {"message": "example error"} }
