This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 87c55b51457 Introduce notion of dialects in DbApiHook (#41327)
87c55b51457 is described below
commit 87c55b51457bf9dafbcbf541ff51940f0455fd15
Author: David Blain <[email protected]>
AuthorDate: Tue Dec 31 20:57:19 2024 +0100
Introduce notion of dialects in DbApiHook (#41327)
* refactor: Added unit test for handlers module in mssql
* refactor: Added unit test for Dialect class
* refactor: Reformatted unit test of Dialect class
* fix: Added missing import of TYPE_CHECKING
* refactor: Added dialects in provider schema and moved MsSqlDialect to
Microsoft mssql provider
* refactor: Removed duplicate handlers and import them from handlers module
* refactor: Fixed import of TYPE_CHECKING in mssql hook
* refactor: Fixed some static checks and imports
* refactor: Dialect should be defined as an array in provider.yaml, not a
single element
* refactor: Fixed default dialect name for common sql provider
* refactor: Fixed dialect name for Microsoft MSSQL provider
* refactor: Fixed module for dialect in python-modules of common sql provider
* refactor: Dialect module is not part of hooks
* refactor: Moved unit tests for default Dialect to common sql provider
instead of Microsoft MSSQL provider
* refactor: Added unit test for MsSqlDialect
* refactor: Reformatted TestMsSqlDialect
* refactor: Implemented dialect resolution using the ProvidersManagers in
DbApiHook
* refactor: Updated comment in dialects property
* refactor: Added dialects lists command
* refactor: Removed unused code from _import_hook method
* refactor: Reformatted _discover_provider_dialects method in
ProvidersManager
* refactor: Removed unused imports from MsSqlHook
* refactor: Removed dialects from DbApiHook definition
* refactor: Reformatted _discover_provider_dialects
* refactor: Renamed module for TestMsSqlDialect
* refactor: test_generate_replace_sql in TestMsSqlHook should only be
tested on Airflow 2.10 or higher
* refactor: Updated expected merge into statement
* refactor: Only run test_generate_replace_sql on TestMsSqlDialect when
Airflow is higher than 2.10
* refactor: generate_insert_sql based on dialects should only be tested on
Airflow 3.0 or higher
* refactor: Updated reason in skipped tests
* refactor: Removed locking in merge into
* refactor: Added kwargs to constructor of Dialect to make it future proof
if additional arguments would be needed in the future
* refactor: Removed row locking clause in generated replace sql statement
and removed pyi file for mssql dialect
* refactor: Implemented PostgresDialect
* fix: Fixed constructor Dialect
* refactor: Register PostgresDialect in providers.yaml and added unit test
for PostgresDialect
* refactor: PostgresHook now uses the dialect to generate statements and
get primary_keys
* refactor: Refactored DbApiHook
* refactor: Refactored the dialect_name mechanism in DbApiHook, override it
in specialized Hooks
* refactor: Fixed some static checks
* refactor: Fixed dialect.pyi
* refactor: Refactored how dialects are resolved, if not found always fall
back to default
* refactor: Reformatted dialect method in DbApiHook
* refactor: Changed message in raised exception of dialect method when not
found
* refactor: Added missing get_records method in Dialect definition
* refactor: Fixed some static checks and mypy issues
* refactor: Raise ValueError if replace_index doesn't exist
* refactor: Increased version of apache-airflow-providers-common-sql to
1.17.1 for mssql and postgres
* refactor: Updated dialect.pyi
* refactor: Updated provider dependencies
* refactor: Incremented version of apache-airflow-providers-common-sql in
test_get_install_requirements
* refactor: Reformatted get_records method
* refactor: Common sql provider must depend on latest Airflow version to be
able to discover dialects through ProvidersManager
* refactor: Updated provider dependencies
* Revert "refactor: Updated provider dependencies"
This reverts commit 2b591f22d67d15700c51f09ba61391a6b73a8a42.
* Revert "refactor: Common sql provider must depend on latest Airflow
version to be able to discover dialects through ProvidersManager"
This reverts commit cb2d043fafb1b614275841849865931fb64e09d4.
* refactor: Added get_dialects method in DbApiHook which contains fallback
code for Airflow 2.8.x so the provider can still be used with Airflow versions
prior to 3.0.0
* fix: get_dialects isn't a property but a method
* refactor: Refactored get_dialects in DbApiHook and added unit tests
* refactor: Added unit tests for MsSqlHook related to dialects
* refactor: Added unit tests for PostgresHook related to dialects
* refactor: Fixed some static checks
* refactor: Removed get_dialects method as this wasn't backward compatible,
avoid importing DialectInfo until min required airflow version is 3.0.0 or
higher
* refactor: Re-added missing deprecated methods for backward compatibility
* refactor: Added resolve_dialects method in sql.pyi
* refactor: Reorganized imports in sql module
* refactor: Fixed definition of resolve_dialects in sql.pyi
* refactor: Fixed TestDialect
* refactor: Fixed DbApi tests and moved tests from DbApi to Odbc
* refactor: Ignore flake8 F811 error as those redefinitions are there for
backward compatibility
* refactor: Move import of Dialect under TYPE_CHECKING block
* refactor: Fixed TestMsSqlDialect
* refactor: Fixed TestPostgresDialect
* refactor: Reformatted MsSqlHook
* refactor: Added docstring on placeholder property
* refactor: If no dialect is found for given dialect name, then return
default Dialect
* refactor: Try ignoring flake8 F811 error as those redefinitions are there
for backward compatibility
* refactor: Moved Dialect out of TYPE_CHECKING block
* fix: Fixed definition location of dialect in dialect.pyi
* fix: Fixed TestTeradataHook
* refactor: Marked
test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_code as
db test
* refactor: Removed handler definitions from sql.pyi
* Revert "refactor: Removed handler definitions from sql.pyi"
This reverts commit a93d73c22f89466d16d9fd36d1865e31bf44784d.
* refactor: Removed white line
* refactor: Removed duplicate imports of handlers
* refactor: Fixed some static checks
* refactor: Changed logging level of generated sql statement to INFO in
DbApiHook
* Revert "refactor: Changed logging level of generated sql statement to
INFO in DbApiHook"
This reverts commit c30feafbce53c9b3dc1bfd1e4467de45b83f3e47.
* fix: Moved dialects to correct providers location
* fix: Deleted old providers location
* fix: Re-added missing dialects for mssql and postgres
* fix: Fixed 2 imports for providers tests
* refactor: Reorganized some imports
* refactor: Fixed dialect and sql types
* refactor: Fixed import of test_utils in test_dag_run
* refactor: Added white line in imports of test_dag_run
* refactor: Escape reserved words as column names
* refactor: Fixed initialisation of Dialects
* refactor: Renamed escape_reserved_word to escape_column_name
* refactor: Reformatted TestMsSqlDialect
* refactor: Fixed constructor definition Dialect
* refactor: Fixed TestDbApiHook
* refactor: Removed get_reserved_words from dialect definition
* refactor: Added logging in get_reserved_words method
* refactor: Removed duplicate reserved_words property in DbApiHook
* refactor: Fixed invocation of reserved_words property and changed name of
postgres dialect to postgresql like in sqlalchemy
* refactor: Removed override generate_insert_sql method in PostgresDialect
as it doesn't do anything different than the existing one in Dialect
* refactor: Added unit tests for _generate_insert_sql methods on MsSqlHook
and PostgresHook
* refactor: Reformatted test mssql and test postgres
* refactor: Fixed TestPostgresDialect
* refactor: Refactored get_reserved_words
* refactor: Added escape column name format so we can customize it if needed
* refactor: Suppress NoSuchModuleError exception when trying to load
dialect from sqlalchemy to get reserved words
* refactor: Removed name from Dialect and added unit test for dialect name
in JdbcHook
* refactor: Fixed parameters in get_column_names method of Dialect
* refactor: Added missing optional schema parameter to get_primary_keys
method of MsSqlDialect
* refactor: Fixed TestDialect
* refactor: Fixed TestDbApiHook
* refactor: Fixed TestMsSqlDialect
* refactor: Reformatted test_generate_replace_sql
* refactor: Fixed dialect in MsSqlHook and PostgresHook
* refactor: Fixed TestPostgresDialect
* refactor: Mark TestMySqlHook as a db test
* refactor: Fixed test_generate_insert_sql_with_already_escaped_column_name
in TestPostgresHook
* refactor: Reactivate postgres backend in TestPostgresHook
* refactor: Removed name param of constructor in Dialect definition
* refactor: Reformatted imports for TestMySqlHook
* refactor: Fixed import of get_provider_min_airflow_version in test sql
* refactor: Override default escape_column_name_format for MySqlHook
* refactor: Fixed tests in TestMySqlHook
* refactor: Refactored INSERT_SQL_STATEMENT constant in TestMySqlHook
* refactor: When using ODBC, we should also use the odbc connection when
creating an sqlalchemy engine
* refactor: Added get_target_fields in Dialect which only returns
insertable column_names and added core.dbapihook_resolve_target_fields
configuration parameter to allow to specify if we want to resolve target_fields
automatically or not
* refactor: By default the core.dbapihook_resolve_target_fields
configuration parameter should be False so the original behaviour is respected
* refactor: Added logging statement for target_fields in Dialect
* refactor: Moved _resolve_target_fields as static field of DbApiHook and
fixed TestMsSqlHook
* refactor: Added test for get_sqlalchemy_engine in OdbcHook
* refactor: Reformatted teardown method
* Revert "refactor: Added test for get_sqlalchemy_engine in OdbcHook"
This reverts commit 871e96b5aa7dc0c0413744848fb97dac4ea61166.
* refactor: Remove patched get_sql_alchemy method in OdbcHook, will fix
this in dedicated PR
* refactor: Removed quotes from schema and table_name before invoking
sqlalchemy inspector methods
* refactor: Removed check in test_sql for Airflow 2.8 plus as it is already
at that min required version
* refactor: Fixed get_primary_keys method in PostgresDialect
* refactor: Reformatted get_primary_keys method of PostgresDialect
* refactor: extract_schema_from_table is now a public classmethod of Dialect
* fix: extract_schema_from_table is now a public classmethod of Dialect
* refactor: Reorganized imports
* refactor: Reorganized imports dialect and postgres
* refactor: Fixed test_dialect in TestProviderManager
* refactor: Removed operators section from provider.yaml in mssql and
postgres
* refactor: Removed unused imports in postgres hook
* refactor: Added missing import for AirflowProviderDeprecationWarning
* refactor: Added rowlock option in merge into statement for MSSQL
* refactor: Updated expected replace statement for MSSQL
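Net effect: a DbApiHook now resolves a Dialect instance (through the
ProvidersManager when available) and delegates SQL generation to it. A minimal
usage sketch, with a hypothetical connection id and the default "%s" placeholder:

    from airflow.providers.postgres.hooks.postgres import PostgresHook

    hook = PostgresHook(postgres_conn_id="my_postgres")  # hypothetical conn id
    print(hook.dialect_name)  # "postgresql"
    sql = hook.dialect.generate_insert_sql(
        "public.users", values=(1, "Alice"), target_fields=["id", "name"]
    )
    # INSERT INTO public.users (id, name) VALUES (%s,%s)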
---------
Co-authored-by: David Blain <[email protected]>
Co-authored-by: David Blain <[email protected]>
---
.../providers/common/sql/dialects/__init__.py | 16 ++
.../providers/common/sql/dialects/dialect.py | 190 ++++++++++++++++++
.../providers/common/sql/dialects/dialect.pyi | 73 +++++++
.../src/airflow/providers/common/sql/hooks/sql.py | 118 +++++++++--
.../src/airflow/providers/common/sql/hooks/sql.pyi | 23 ++-
.../src/airflow/providers/common/sql/provider.yaml | 4 +
.../providers/microsoft/mssql/dialects/__init__.py | 16 ++
.../providers/microsoft/mssql/dialects/mssql.py | 64 ++++++
.../providers/microsoft/mssql/hooks/mssql.py | 67 ++-----
.../providers/microsoft/mssql/provider.yaml | 4 +
.../src/airflow/providers/mysql/hooks/mysql.py | 1 +
.../providers/postgres/dialects/__init__.py | 16 ++
.../providers/postgres/dialects/postgres.py | 91 +++++++++
.../airflow/providers/postgres/hooks/postgres.py | 73 +------
.../src/airflow/providers/postgres/provider.yaml | 4 +
providers/tests/common/sql/dialects/__init__.py | 16 ++
.../tests/common/sql/dialects/test_dialect.py | 67 +++++++
providers/tests/common/sql/hooks/test_dbapi.py | 5 +
providers/tests/common/sql/hooks/test_sql.py | 34 +++-
providers/tests/jdbc/hooks/test_jdbc.py | 25 ++-
.../tests/microsoft/mssql/dialects/__init__.py | 16 ++
.../tests/microsoft/mssql/dialects/test_mssql.py | 83 ++++++++
.../tests/microsoft/mssql/hooks/test_mssql.py | 221 +++++++++++++--------
.../tests/microsoft/mssql/resources/replace.sql | 14 +-
providers/tests/mysql/hooks/test_mysql.py | 85 +++++++-
providers/tests/odbc/hooks/test_odbc.py | 25 +++
providers/tests/postgres/dialects/__init__.py | 16 ++
providers/tests/postgres/dialects/test_postgres.py | 87 ++++++++
providers/tests/postgres/hooks/test_postgres.py | 82 ++++++++
providers/tests/teradata/hooks/test_teradata.py | 9 +-
tests/always/test_providers_manager.py | 3 +-
31 files changed, 1317 insertions(+), 231 deletions(-)
diff --git a/providers/src/airflow/providers/common/sql/dialects/__init__.py b/providers/src/airflow/providers/common/sql/dialects/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/src/airflow/providers/common/sql/dialects/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/src/airflow/providers/common/sql/dialects/dialect.py b/providers/src/airflow/providers/common/sql/dialects/dialect.py
new file mode 100644
index 00000000000..184e6a5ce4e
--- /dev/null
+++ b/providers/src/airflow/providers/common/sql/dialects/dialect.py
@@ -0,0 +1,190 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import re
+from collections.abc import Iterable, Mapping
+from typing import TYPE_CHECKING, Any, Callable, TypeVar
+
+from methodtools import lru_cache
+
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+if TYPE_CHECKING:
+ from sqlalchemy.engine import Inspector
+
+T = TypeVar("T")
+
+
+class Dialect(LoggingMixin):
+ """Generic dialect implementation."""
+
+ pattern = re.compile(r'"([a-zA-Z0-9_]+)"')
+
+ def __init__(self, hook, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+ if not isinstance(hook, DbApiHook):
+ raise TypeError(f"hook must be an instance of
{DbApiHook.__class__.__name__}")
+
+ self.hook: DbApiHook = hook
+
+ @classmethod
+ def remove_quotes(cls, value: str | None) -> str | None:
+ if value:
+ return cls.pattern.sub(r"\1", value)
+
+ @property
+ def placeholder(self) -> str:
+ return self.hook.placeholder
+
+ @property
+ def inspector(self) -> Inspector:
+ return self.hook.inspector
+
+ @property
+ def _insert_statement_format(self) -> str:
+ return self.hook._insert_statement_format # type: ignore
+
+ @property
+ def _replace_statement_format(self) -> str:
+ return self.hook._replace_statement_format # type: ignore
+
+ @property
+ def _escape_column_name_format(self) -> str:
+ return self.hook._escape_column_name_format # type: ignore
+
+ @classmethod
+ def extract_schema_from_table(cls, table: str) -> tuple[str, str | None]:
+ parts = table.split(".")
+ return tuple(parts[::-1]) if len(parts) == 2 else (table, None)
+
+ @lru_cache(maxsize=None)
+ def get_column_names(
+        self, table: str, schema: str | None = None, predicate: Callable[[T], bool] = lambda column: True
+ ) -> list[str] | None:
+ if schema is None:
+ table, schema = self.extract_schema_from_table(table)
+ column_names = list(
+ column["name"]
+ for column in filter(
+ predicate,
+ self.inspector.get_columns(
+ table_name=self.remove_quotes(table),
+ schema=self.remove_quotes(schema) if schema else None,
+ ),
+ )
+ )
+ self.log.debug("Column names for table '%s': %s", table, column_names)
+ return column_names
+
+ @lru_cache(maxsize=None)
+    def get_target_fields(self, table: str, schema: str | None = None) -> list[str] | None:
+        target_fields = self.get_column_names(
+            table,
+            schema,
+            lambda column: not column.get("identity", False) and not column.get("autoincrement", False),
+        )
+        self.log.debug("Target fields for table '%s': %s", table, target_fields)
+        return target_fields
+
+ @lru_cache(maxsize=None)
+    def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] | None:
+ if schema is None:
+ table, schema = self.extract_schema_from_table(table)
+ primary_keys = self.inspector.get_pk_constraint(
+ table_name=self.remove_quotes(table),
+ schema=self.remove_quotes(schema) if schema else None,
+ ).get("constrained_columns", [])
+ self.log.debug("Primary keys for table '%s': %s", table, primary_keys)
+ return primary_keys
+
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = False,
+ parameters: Iterable | Mapping[str, Any] | None = None,
+ handler: Callable[[Any], T] | None = None,
+ split_statements: bool = False,
+ return_last: bool = True,
+ ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
+        return self.hook.run(sql, autocommit, parameters, handler, split_statements, return_last)
+
+ def get_records(
+ self,
+ sql: str | list[str],
+ parameters: Iterable | Mapping[str, Any] | None = None,
+ ) -> Any:
+ return self.hook.get_records(sql=sql, parameters=parameters)
+
+ @property
+ def reserved_words(self) -> set[str]:
+ return self.hook.reserved_words
+
+ def escape_column_name(self, column_name: str) -> str:
+ """
+ Escape the column name if it's a reserved word.
+
+ :param column_name: Name of the column
+ :return: The escaped column name if needed
+ """
+ if (
+ column_name != self._escape_column_name_format.format(column_name)
+ and column_name.casefold() in self.reserved_words
+ ):
+ return self._escape_column_name_format.format(column_name)
+ return column_name
+
+ def _joined_placeholders(self, values) -> str:
+ placeholders = [
+ self.placeholder,
+ ] * len(values)
+ return ",".join(placeholders)
+
+ def _joined_target_fields(self, target_fields) -> str:
+ if target_fields:
+ target_fields = ", ".join(map(self.escape_column_name,
target_fields))
+ return f"({target_fields})"
+ return ""
+
+    def generate_insert_sql(self, table, values, target_fields, **kwargs) -> str:
+ """
+ Generate the INSERT SQL statement.
+
+ :param table: Name of the target table
+ :param values: The row to insert into the table
+ :param target_fields: The names of the columns to fill in the table
+ :return: The generated INSERT SQL statement
+ """
+ return self._insert_statement_format.format(
+            table, self._joined_target_fields(target_fields), self._joined_placeholders(values)
+ )
+
+    def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str:
+ """
+ Generate the REPLACE SQL statement.
+
+ :param table: Name of the target table
+ :param values: The row to insert into the table
+ :param target_fields: The names of the columns to fill in the table
+ :return: The generated REPLACE SQL statement
+ """
+ return self._replace_statement_format.format(
+            table, self._joined_target_fields(target_fields), self._joined_placeholders(values)
+ )
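As an aside on the escaping behaviour above, a short sketch (assuming "order" is
in the dialect's reserved words and the default '"{}"' escape format):

    dialect = hook.dialect  # any DbApiHook's resolved dialect
    dialect.escape_column_name("order")    # -> '"order"' (reserved word, escaped)
    dialect.escape_column_name("name")     # -> 'name' (not reserved, unchanged)
    dialect.escape_column_name('"order"')  # -> '"order"' (already escaped, unchanged)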
diff --git a/providers/src/airflow/providers/common/sql/dialects/dialect.pyi b/providers/src/airflow/providers/common/sql/dialects/dialect.pyi
new file mode 100644
index 00000000000..423fab3ccd0
--- /dev/null
+++ b/providers/src/airflow/providers/common/sql/dialects/dialect.pyi
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This is automatically generated stub for the `common.sql` provider
+#
+# This file is generated automatically by the `update-common-sql-api stubs` pre-commit
+# and the .pyi file represents part of the "public" API that the
+# `common.sql` provider exposes to other providers.
+#
+# Any, potentially breaking change in the stubs will require deliberate manual action from the contributor
+# making a change to the `common.sql` provider. Those stubs are also used by MyPy automatically when checking
+# if only public API of the common.sql provider is used by all the other providers.
+#
+# You can read more in the README_API.md file
+#
+"""
+Definition of the public interface for airflow.providers.common.sql.dialects.dialect
+isort:skip_file
+"""
+from _typeshed import Incomplete as Incomplete
+from airflow.utils.log.logging_mixin import LoggingMixin as LoggingMixin
+from sqlalchemy.engine import Inspector as Inspector
+from typing import Any, Callable, Iterable, Mapping, TypeVar
+
+T = TypeVar("T")
+
+class Dialect(LoggingMixin):
+ hook: Incomplete
+ def __init__(self, hook, **kwargs) -> None: ...
+ @classmethod
+ def remove_quotes(cls, value: str | None) -> str | None: ...
+ @property
+ def placeholder(self) -> str: ...
+ @property
+ def inspector(self) -> Inspector: ...
+ @classmethod
+    def extract_schema_from_table(cls, table: str) -> tuple[str, str | None]: ...
+ def get_column_names(
+        self, table: str, schema: str | None = None, predicate: Callable[[T], bool] = ...
+    ) -> list[str] | None: ...
+    def get_target_fields(self, table: str, schema: str | None = None) -> list[str] | None: ...
+    def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] | None: ...
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = False,
+ parameters: Iterable | Mapping[str, Any] | None = None,
+ handler: Callable[[Any], T] | None = None,
+ split_statements: bool = False,
+ return_last: bool = True,
+ ) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ...
+ def get_records(
+        self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None
+ ) -> Any: ...
+ @property
+ def reserved_words(self) -> set[str]: ...
+ def escape_column_name(self, column_name: str) -> str: ...
+    def generate_insert_sql(self, table, values, target_fields, **kwargs) -> str: ...
+    def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: ...
diff --git a/providers/src/airflow/providers/common/sql/hooks/sql.py b/providers/src/airflow/providers/common/sql/hooks/sql.py
index f4d107f0c5f..25d25eaec17 100644
--- a/providers/src/airflow/providers/common/sql/hooks/sql.py
+++ b/providers/src/airflow/providers/common/sql/hooks/sql.py
@@ -18,8 +18,8 @@ from __future__ import annotations
import contextlib
import warnings
-from collections.abc import Generator, Iterable, Mapping, Sequence
-from contextlib import closing, contextmanager
+from collections.abc import Generator, Iterable, Mapping, MutableMapping, Sequence
+from contextlib import closing, contextmanager, suppress
from datetime import datetime
from functools import cached_property
from typing import (
@@ -34,15 +34,20 @@ from typing import (
from urllib.parse import urlparse
import sqlparse
+from methodtools import lru_cache
from more_itertools import chunked
from sqlalchemy import create_engine
-from sqlalchemy.engine import Inspector
+from sqlalchemy.engine import Inspector, make_url
+from sqlalchemy.exc import ArgumentError, NoSuchModuleError
+from airflow.configuration import conf
from airflow.exceptions import (
AirflowException,
AirflowOptionalProviderFeatureException,
)
from airflow.hooks.base import BaseHook
+from airflow.providers.common.sql.dialects.dialect import Dialect
+from airflow.utils.module_loading import import_string
if TYPE_CHECKING:
from pandas import DataFrame
@@ -83,6 +88,36 @@ def fetch_one_handler(cursor) -> list[tuple] | None:
return handlers.fetch_one_handler(cursor)
+def resolve_dialects() -> MutableMapping[str, MutableMapping]:
+ from airflow.providers_manager import ProvidersManager
+
+ providers_manager = ProvidersManager()
+
+    # TODO: this check can be removed once common sql provider depends on Airflow 3.0 or higher,
+    #       we could then also use DialectInfo and won't need to convert it to a dict.
+    if hasattr(providers_manager, "dialects"):
+        return {key: dict(value._asdict()) for key, value in providers_manager.dialects.items()}
+
+    # TODO: this can be removed once common sql provider depends on Airflow 3.0 or higher
+ return {
+ "default": dict(
+ name="default",
+            dialect_class_name="airflow.providers.common.sql.dialects.dialect.Dialect",
+ provider_name="apache-airflow-providers-common-sql",
+ ),
+ "mssql": dict(
+ name="mssql",
+            dialect_class_name="airflow.providers.microsoft.mssql.dialects.mssql.MsSqlDialect",
+ provider_name="apache-airflow-providers-microsoft-mssql",
+ ),
+ "postgresql": dict(
+ name="postgresql",
+            dialect_class_name="airflow.providers.postgres.dialects.postgres.PostgresDialect",
+ provider_name="apache-airflow-providers-postgres",
+ ),
+ }
+
+
class ConnectorProtocol(Protocol):
"""Database connection protocol."""
@@ -129,6 +164,8 @@ class DbApiHook(BaseHook):
_test_connection_sql = "select 1"
# Default SQL placeholder
_placeholder: str = "%s"
+ _dialects: MutableMapping[str, MutableMapping] = resolve_dialects()
+ _resolve_target_fields = conf.getboolean("core",
"dbapihook_resolve_target_fields", fallback=False)
def __init__(self, *args, schema: str | None = None, log_sql: bool = True,
**kwargs):
super().__init__()
@@ -153,6 +190,7 @@ class DbApiHook(BaseHook):
self._replace_statement_format: str = kwargs.get(
"replace_statement_format", "REPLACE INTO {} {} VALUES ({})"
)
+        self._escape_column_name_format: str = kwargs.get("escape_column_name_format", '"{}"')
self._connection: Connection | None = kwargs.pop("connection", None)
def get_conn_id(self) -> str:
@@ -262,6 +300,57 @@ class DbApiHook(BaseHook):
def inspector(self) -> Inspector:
return Inspector.from_engine(self.get_sqlalchemy_engine())
+ @cached_property
+ def dialect_name(self) -> str:
+ try:
+ return make_url(self.get_uri()).get_dialect().name
+ except (ArgumentError, NoSuchModuleError):
+ config = self.connection_extra
+ sqlalchemy_scheme = config.get("sqlalchemy_scheme")
+ if sqlalchemy_scheme:
+ return sqlalchemy_scheme.split("+")[0] if "+" in
sqlalchemy_scheme else sqlalchemy_scheme
+ return config.get("dialect", "default")
+
+ @cached_property
+ def dialect(self) -> Dialect:
+ from airflow.utils.module_loading import import_string
+
+ dialect_info = self._dialects.get(self.dialect_name)
+
+ self.log.debug("dialect_info: %s", dialect_info)
+
+ if dialect_info:
+ try:
+ return import_string(dialect_info["dialect_class_name"])(self)
+ except ImportError:
+ raise AirflowOptionalProviderFeatureException(
+ f"{dialect_info.dialect_class_name} not found, run: pip
install "
+ f"'{dialect_info.provider_name}'."
+ )
+ return Dialect(self)
+
+ @property
+ def reserved_words(self) -> set[str]:
+ return self.get_reserved_words(self.dialect_name)
+
+ @lru_cache(maxsize=None)
+ def get_reserved_words(self, dialect_name: str) -> set[str]:
+ result = set()
+ with suppress(ImportError, ModuleNotFoundError, NoSuchModuleError):
+            dialect_module = import_string(f"sqlalchemy.dialects.{dialect_name}.base")
+
+ if hasattr(dialect_module, "RESERVED_WORDS"):
+ result = set(dialect_module.RESERVED_WORDS)
+ else:
+                dialect_module = import_string(f"sqlalchemy.dialects.{dialect_name}.reserved_words")
+ reserved_words_attr = f"RESERVED_WORDS_{dialect_name.upper()}"
+
+ if hasattr(dialect_module, reserved_words_attr):
+ result = set(getattr(dialect_module, reserved_words_attr))
+
+ self.log.debug("reserved words for '%s': %s", dialect_name, result)
+ return result
+
def get_pandas_df(
self,
sql,
@@ -543,7 +632,7 @@ class DbApiHook(BaseHook):
"""Return a cursor."""
return self.get_conn().cursor()
-    def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) -> str:
+    def _generate_insert_sql(self, table, values, target_fields=None, replace: bool = False, **kwargs) -> str:
"""
Generate the INSERT SQL statement.
@@ -551,24 +640,19 @@ class DbApiHook(BaseHook):
:param table: Name of the target table
:param values: The row to insert into the table
- :param target_fields: The names of the columns to fill in the table
+        :param target_fields: The names of the columns to fill in the table. If no target fields are
+            specified, they will be determined dynamically from the table's metadata.
:param replace: Whether to replace/upsert instead of insert
:return: The generated INSERT or REPLACE/UPSERT SQL statement
"""
- placeholders = [
- self.placeholder,
- ] * len(values)
-
- if target_fields:
- target_fields = ", ".join(target_fields)
- target_fields = f"({target_fields})"
- else:
- target_fields = ""
+ if not target_fields and self._resolve_target_fields:
+ with suppress(Exception):
+ target_fields = self.dialect.get_target_fields(table)
-        if not replace:
-            return self._insert_statement_format.format(table, target_fields, ",".join(placeholders))
+        if replace:
+            return self.dialect.generate_replace_sql(table, values, target_fields, **kwargs)
-        return self._replace_statement_format.format(table, target_fields, ",".join(placeholders))
+        return self.dialect.generate_insert_sql(table, values, target_fields, **kwargs)
@contextmanager
def _create_autocommit_connection(self, autocommit: bool = False):
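The automatic target-field resolution in _generate_insert_sql above is opt-in via
the new core.dbapihook_resolve_target_fields option (default False, preserving the
original behaviour). A sketch of enabling it, using Airflow's usual configuration
conventions:

    # airflow.cfg
    [core]
    dbapihook_resolve_target_fields = True

    # or, equivalently, via environment variable
    AIRFLOW__CORE__DBAPIHOOK_RESOLVE_TARGET_FIELDS=True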
diff --git a/providers/src/airflow/providers/common/sql/hooks/sql.pyi b/providers/src/airflow/providers/common/sql/hooks/sql.pyi
index ed93958401e..afa9754a8b6 100644
--- a/providers/src/airflow/providers/common/sql/hooks/sql.pyi
+++ b/providers/src/airflow/providers/common/sql/hooks/sql.pyi
@@ -34,19 +34,33 @@ isort:skip_file
from _typeshed import Incomplete as Incomplete
from airflow.hooks.base import BaseHook as BaseHook
from airflow.models import Connection as Connection
+from airflow.providers.common.sql.dialects.dialect import Dialect as Dialect
from airflow.providers.openlineage.extractors import OperatorLineage as OperatorLineage
from airflow.providers.openlineage.sqlparser import DatabaseInfo as DatabaseInfo
from functools import cached_property as cached_property
from pandas import DataFrame as DataFrame
from sqlalchemy.engine import Inspector as Inspector, URL as URL
-from typing import Any, Callable, Generator, Iterable, Mapping, Protocol, Sequence, TypeVar, overload
+from typing import (
+ Any,
+ Callable,
+ Generator,
+ Iterable,
+ Mapping,
+ MutableMapping,
+ Protocol,
+ Sequence,
+ TypeVar,
+ overload,
+)
T = TypeVar("T")
SQL_PLACEHOLDERS: Incomplete
+WARNING_MESSAGE: str
def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool): ...
def fetch_all_handler(cursor) -> list[tuple] | None: ...
def fetch_one_handler(cursor) -> list[tuple] | None: ...
+def resolve_dialects() -> MutableMapping[str, MutableMapping]: ...
class ConnectorProtocol(Protocol):
def connect(self, host: str, port: int, username: str, schema: str) -> Any: ...
@@ -79,6 +93,13 @@ class DbApiHook(BaseHook):
def get_sqlalchemy_engine(self, engine_kwargs: Incomplete | None = None): ...
@property
def inspector(self) -> Inspector: ...
+ @cached_property
+ def dialect_name(self) -> str: ...
+ @cached_property
+ def dialect(self) -> Dialect: ...
+ @property
+ def reserved_words(self) -> set[str]: ...
+ def get_reserved_words(self, dialect_name: str) -> set[str]: ...
def get_pandas_df(
self, sql, parameters: list | tuple | Mapping[str, Any] | None = None, **kwargs
) -> DataFrame: ...
diff --git a/providers/src/airflow/providers/common/sql/provider.yaml b/providers/src/airflow/providers/common/sql/provider.yaml
index 32bfbe2d493..530cc351882 100644
--- a/providers/src/airflow/providers/common/sql/provider.yaml
+++ b/providers/src/airflow/providers/common/sql/provider.yaml
@@ -93,6 +93,10 @@ operators:
python-modules:
- airflow.providers.common.sql.operators.sql
+dialects:
+ - dialect-type: default
+ dialect-class-name: airflow.providers.common.sql.dialects.dialect.Dialect
+
hooks:
- integration-name: Common SQL
python-modules:
diff --git a/providers/src/airflow/providers/microsoft/mssql/dialects/__init__.py b/providers/src/airflow/providers/microsoft/mssql/dialects/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/src/airflow/providers/microsoft/mssql/dialects/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/src/airflow/providers/microsoft/mssql/dialects/mssql.py b/providers/src/airflow/providers/microsoft/mssql/dialects/mssql.py
new file mode 100644
index 00000000000..fc2110a762d
--- /dev/null
+++ b/providers/src/airflow/providers/microsoft/mssql/dialects/mssql.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from methodtools import lru_cache
+
+from airflow.providers.common.sql.dialects.dialect import Dialect
+from airflow.providers.common.sql.hooks.handlers import fetch_all_handler
+
+
+class MsSqlDialect(Dialect):
+ """Microsoft SQL Server dialect implementation."""
+
+ @lru_cache(maxsize=None)
+    def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] | None:
+ primary_keys = self.run(
+ f"""
+ SELECT c.name
+ FROM sys.columns c
+ WHERE c.object_id = OBJECT_ID('{table}')
+ AND EXISTS (SELECT 1 FROM sys.index_columns ic
+                INNER JOIN sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id
+ WHERE i.is_primary_key = 1
+ AND ic.object_id = c.object_id
+ AND ic.column_id = c.column_id);
+ """,
+ handler=fetch_all_handler,
+ )
+        primary_keys = [pk[0] for pk in primary_keys] if primary_keys else []  # type: ignore
+ self.log.debug("Primary keys for table '%s': %s", table, primary_keys)
+ return primary_keys # type: ignore
+
+    def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str:
+ primary_keys = self.get_primary_keys(table)
+ columns = [
+ self.escape_column_name(target_field)
+ for target_field in target_fields
+ if target_field in set(target_fields).difference(set(primary_keys))
+ ]
+
+ self.log.debug("primary_keys: %s", primary_keys)
+ self.log.debug("columns: %s", columns)
+
+ return f"""MERGE INTO {table} WITH (ROWLOCK) AS target
+ USING (SELECT {', '.join(map(lambda column: f'{self.placeholder}
AS {column}', target_fields))}) AS source
+ ON {' AND '.join(map(lambda column:
f'target.{self.escape_column_name(column)} = source.{column}', primary_keys))}
+ WHEN MATCHED THEN
+ UPDATE SET {', '.join(map(lambda column: f'target.{column} =
source.{column}', columns))}
+ WHEN NOT MATCHED THEN
+ INSERT ({', '.join(target_fields)}) VALUES ({',
'.join(map(lambda column: f'source.{self.escape_column_name(column)}',
target_fields))});"""
diff --git a/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py b/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
index a367250ed33..089f1ccfb7d 100644
--- a/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
+++ b/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
@@ -19,13 +19,16 @@
from __future__ import annotations
-from typing import Any
+from typing import TYPE_CHECKING, Any
import pymssql
-from methodtools import lru_cache
from pymssql import Connection as PymssqlConnection
-from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
+from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect
+
+if TYPE_CHECKING:
+ from airflow.providers.common.sql.dialects.dialect import Dialect
class MsSqlHook(DbApiHook):
@@ -63,6 +66,14 @@ class MsSqlHook(DbApiHook):
raise RuntimeError("sqlalchemy_scheme in connection extra should
not contain : or / characters")
return self._sqlalchemy_scheme or extra_scheme or
self.DEFAULT_SQLALCHEMY_SCHEME
+ @property
+ def dialect_name(self) -> str:
+ return "mssql"
+
+ @property
+ def dialect(self) -> Dialect:
+ return MsSqlDialect(self)
+
def get_uri(self) -> str:
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
@@ -84,56 +95,6 @@ class MsSqlHook(DbApiHook):
engine = self.get_sqlalchemy_engine(engine_kwargs=engine_kwargs)
return engine.connect(**(connect_kwargs or {}))
- @lru_cache(maxsize=None)
- def get_primary_keys(self, table: str) -> list[str]:
- primary_keys = self.run(
- f"""
- SELECT c.name
- FROM sys.columns c
- WHERE c.object_id = OBJECT_ID('{table}')
- AND EXISTS (SELECT 1 FROM sys.index_columns ic
-            INNER JOIN sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id
- WHERE i.is_primary_key = 1
- AND ic.object_id = c.object_id
- AND ic.column_id = c.column_id);
- """,
- handler=fetch_all_handler,
- )
- return [pk[0] for pk in primary_keys] # type: ignore
-
-    def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) -> str:
- """
- Generate the INSERT SQL statement.
-
- The MERGE INTO variant is specific to MSSQL syntax
-
- :param table: Name of the target table
- :param values: The row to insert into the table
- :param target_fields: The names of the columns to fill in the table
- :param replace: Whether to replace/merge into instead of insert
- :return: The generated INSERT or MERGE INTO SQL statement
- """
- if not replace:
-            return super()._generate_insert_sql(table, values, target_fields, replace, **kwargs)  # type: ignore
-
- primary_keys = self.get_primary_keys(table)
- columns = [
- target_field
- for target_field in target_fields
- if target_field in set(target_fields).difference(set(primary_keys))
- ]
-
- self.log.debug("primary_keys: %s", primary_keys)
- self.log.info("columns: %s", columns)
-
- return f"""MERGE INTO {table} AS target
- USING (SELECT {', '.join(map(lambda column: f'{self.placeholder} AS
{column}', target_fields))}) AS source
- ON {' AND '.join(map(lambda column: f'target.{column} =
source.{column}', primary_keys))}
- WHEN MATCHED THEN
- UPDATE SET {', '.join(map(lambda column: f'target.{column} =
source.{column}', columns))}
- WHEN NOT MATCHED THEN
- INSERT ({', '.join(target_fields)}) VALUES ({', '.join(map(lambda
column: f'source.{column}', target_fields))});"""
-
def get_conn(self) -> PymssqlConnection:
"""Return ``pymssql`` connection object."""
conn = self.connection
diff --git a/providers/src/airflow/providers/microsoft/mssql/provider.yaml b/providers/src/airflow/providers/microsoft/mssql/provider.yaml
index b5d0e634798..d29ee70d8a5 100644
--- a/providers/src/airflow/providers/microsoft/mssql/provider.yaml
+++ b/providers/src/airflow/providers/microsoft/mssql/provider.yaml
@@ -72,6 +72,10 @@ integrations:
- /docs/apache-airflow-providers-microsoft-mssql/operators.rst
tags: [software]
+dialects:
+ - dialect-type: mssql
+    dialect-class-name: airflow.providers.microsoft.mssql.dialects.mssql.MsSqlDialect
+
hooks:
- integration-name: Microsoft SQL Server (MSSQL)
python-modules:
diff --git a/providers/src/airflow/providers/mysql/hooks/mysql.py b/providers/src/airflow/providers/mysql/hooks/mysql.py
index 5ed8a62d75f..48185c1cf55 100644
--- a/providers/src/airflow/providers/mysql/hooks/mysql.py
+++ b/providers/src/airflow/providers/mysql/hooks/mysql.py
@@ -82,6 +82,7 @@ class MySqlHook(DbApiHook):
self.schema = kwargs.pop("schema", None)
self.local_infile = kwargs.pop("local_infile", False)
self.init_command = kwargs.pop("init_command", None)
+        self._escape_column_name_format: str = kwargs.get("escape_column_name_format", "`{}`")
def set_autocommit(self, conn: MySQLConnectionTypes, autocommit: bool) -> None:
"""
diff --git a/providers/src/airflow/providers/postgres/dialects/__init__.py b/providers/src/airflow/providers/postgres/dialects/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/src/airflow/providers/postgres/dialects/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/src/airflow/providers/postgres/dialects/postgres.py b/providers/src/airflow/providers/postgres/dialects/postgres.py
new file mode 100644
index 00000000000..5db4cca18f8
--- /dev/null
+++ b/providers/src/airflow/providers/postgres/dialects/postgres.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from methodtools import lru_cache
+
+from airflow.providers.common.sql.dialects.dialect import Dialect
+
+
+class PostgresDialect(Dialect):
+ """Postgres dialect implementation."""
+
+ @property
+ def name(self) -> str:
+ return "postgresql"
+
+ @lru_cache(maxsize=None)
+    def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] | None:
+ """
+ Get the table's primary key.
+
+ :param table: Name of the target table
+ :param schema: Name of the target schema, public by default
+ :return: Primary key columns list
+ """
+ if schema is None:
+ table, schema = self.extract_schema_from_table(table)
+ sql = """
+ select kcu.column_name
+ from information_schema.table_constraints tco
+ join information_schema.key_column_usage kcu
+ on kcu.constraint_name = tco.constraint_name
+ and kcu.constraint_schema = tco.constraint_schema
+ and kcu.constraint_name = tco.constraint_name
+ where tco.constraint_type = 'PRIMARY KEY'
+ and kcu.table_schema = %s
+ and kcu.table_name = %s
+ """
+ pk_columns = [
+            row[0] for row in self.get_records(sql, (self.remove_quotes(schema), self.remove_quotes(table)))
+ ]
+ return pk_columns or None
+
+    def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str:
+ """
+ Generate the REPLACE SQL statement.
+
+ :param table: Name of the target table
+ :param values: The row to insert into the table
+ :param target_fields: The names of the columns to fill in the table
+        :param replace_index: the column or list of column names to act as
+            index for the ON CONFLICT clause
+        :return: The generated REPLACE SQL statement
+ """
+ if not target_fields:
+ raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires
column names")
+
+ replace_index = kwargs.get("replace_index") or
self.get_primary_keys(table)
+
+ if not replace_index:
+ raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires an
unique index")
+
+ if isinstance(replace_index, str):
+ replace_index = [replace_index]
+
+ sql = self.generate_insert_sql(table, values, target_fields, **kwargs)
+ on_conflict_str = f" ON CONFLICT ({',
'.join(map(self.escape_column_name, replace_index))})"
+ replace_target = [self.escape_column_name(f) for f in target_fields if
f not in replace_index]
+
+ if replace_target:
+ replace_target_str = ", ".join(f"{col} = excluded.{col}" for col
in replace_target)
+ sql += f"{on_conflict_str} DO UPDATE SET {replace_target_str}"
+ else:
+ sql += f"{on_conflict_str} DO NOTHING"
+
+ return sql
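Again for intuition, with the same hypothetical table hollywood.actors (primary
key "id", target fields ["id", "name"]), the generated upsert reads roughly:

    INSERT INTO hollywood.actors (id, name) VALUES (%s,%s)
    ON CONFLICT (id) DO UPDATE SET name = excluded.name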
diff --git a/providers/src/airflow/providers/postgres/hooks/postgres.py b/providers/src/airflow/providers/postgres/hooks/postgres.py
index 9b657c14416..e760ebaab45 100644
--- a/providers/src/airflow/providers/postgres/hooks/postgres.py
+++ b/providers/src/airflow/providers/postgres/hooks/postgres.py
@@ -18,7 +18,6 @@
from __future__ import annotations
import os
-from collections.abc import Iterable
from contextlib import closing
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Union
@@ -31,11 +30,13 @@ from sqlalchemy.engine import URL
from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.postgres.dialects.postgres import PostgresDialect
if TYPE_CHECKING:
from psycopg2.extensions import connection
from airflow.models.connection import Connection
+ from airflow.providers.common.sql.dialects.dialect import Dialect
from airflow.providers.openlineage.sqlparser import DatabaseInfo
CursorType = Union[DictCursor, RealDictCursor, NamedTupleCursor]
@@ -123,6 +124,14 @@ class PostgresHook(DbApiHook):
query=query,
)
+ @property
+ def dialect_name(self) -> str:
+ return "postgresql"
+
+ @property
+ def dialect(self) -> Dialect:
+ return PostgresDialect(self)
+
def _get_cursor(self, raw_cursor: str) -> CursorType:
_cursor = raw_cursor.lower()
cursor_types = {
@@ -286,67 +295,7 @@ class PostgresHook(DbApiHook):
:param schema: Name of the target schema, public by default
:return: Primary key columns list
"""
- sql = """
- select kcu.column_name
- from information_schema.table_constraints tco
- join information_schema.key_column_usage kcu
- on kcu.constraint_name = tco.constraint_name
- and kcu.constraint_schema = tco.constraint_schema
- and kcu.constraint_name = tco.constraint_name
- where tco.constraint_type = 'PRIMARY KEY'
- and kcu.table_schema = %s
- and kcu.table_name = %s
- """
- pk_columns = [row[0] for row in self.get_records(sql, (schema, table))]
- return pk_columns or None
-
- def _generate_insert_sql(
-        self, table: str, values: tuple[str, ...], target_fields: Iterable[str], replace: bool, **kwargs
- ) -> str:
- """
- Generate the INSERT SQL statement.
-
- The REPLACE variant is specific to the PostgreSQL syntax.
-
- :param table: Name of the target table
- :param values: The row to insert into the table
- :param target_fields: The names of the columns to fill in the table
- :param replace: Whether to replace instead of insert
- :param replace_index: the column or list of column names to act as
- index for the ON CONFLICT clause
- :return: The generated INSERT or REPLACE SQL statement
- """
- placeholders = [
- self.placeholder,
- ] * len(values)
- replace_index = kwargs.get("replace_index")
-
- if target_fields:
- target_fields_fragment = ", ".join(target_fields)
- target_fields_fragment = f"({target_fields_fragment})"
- else:
- target_fields_fragment = ""
-
- sql = f"INSERT INTO {table} {target_fields_fragment} VALUES
({','.join(placeholders)})"
-
- if replace:
- if not target_fields:
- raise ValueError("PostgreSQL ON CONFLICT upsert syntax
requires column names")
- if not replace_index:
- raise ValueError("PostgreSQL ON CONFLICT upsert syntax
requires an unique index")
- if isinstance(replace_index, str):
- replace_index = [replace_index]
-
- on_conflict_str = f" ON CONFLICT ({', '.join(replace_index)})"
-            replace_target = [f for f in target_fields if f not in replace_index]
-
- if replace_target:
- replace_target_str = ", ".join(f"{col} = excluded.{col}" for
col in replace_target)
- sql += f"{on_conflict_str} DO UPDATE SET {replace_target_str}"
- else:
- sql += f"{on_conflict_str} DO NOTHING"
-
- return sql
+ return self.dialect.get_primary_keys(table=table, schema=schema)
def get_openlineage_database_info(self, connection) -> DatabaseInfo:
"""Return Postgres/Redshift specific information for OpenLineage."""
diff --git a/providers/src/airflow/providers/postgres/provider.yaml b/providers/src/airflow/providers/postgres/provider.yaml
index 13425797e9c..3b2b0941194 100644
--- a/providers/src/airflow/providers/postgres/provider.yaml
+++ b/providers/src/airflow/providers/postgres/provider.yaml
@@ -86,6 +86,10 @@ integrations:
logo: /integration-logos/postgres/Postgres.png
tags: [software]
+dialects:
+ - dialect-type: postgresql
+    dialect-class-name: airflow.providers.postgres.dialects.postgres.PostgresDialect
+
hooks:
- integration-name: PostgreSQL
python-modules:
diff --git a/providers/tests/common/sql/dialects/__init__.py b/providers/tests/common/sql/dialects/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/tests/common/sql/dialects/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/tests/common/sql/dialects/test_dialect.py b/providers/tests/common/sql/dialects/test_dialect.py
new file mode 100644
index 00000000000..1021b1c617c
--- /dev/null
+++ b/providers/tests/common/sql/dialects/test_dialect.py
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+from sqlalchemy.engine import Inspector
+
+from airflow.providers.common.sql.dialects.dialect import Dialect
+from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+
+class TestDialect:
+ def setup_method(self):
+        inspector = MagicMock(spec=Inspector)
+ inspector.get_columns.side_effect = lambda table_name, schema: [
+ {"name": "id", "identity": True},
+ {"name": "name"},
+ {"name": "firstname"},
+ {"name": "age"},
+ ]
+        inspector.get_pk_constraint.side_effect = lambda table_name, schema: {"constrained_columns": ["id"]}
+ self.test_db_hook = MagicMock(placeholder="?", inspector=inspector,
spec=DbApiHook)
+
+ def test_remove_quotes(self):
+ assert not Dialect.remove_quotes(None)
+ assert Dialect.remove_quotes("table") == "table"
+ assert Dialect.remove_quotes('"table"') == "table"
+
+ def test_placeholder(self):
+ assert Dialect(self.test_db_hook).placeholder == "?"
+
+ def test_extract_schema_from_table(self):
+ assert Dialect.extract_schema_from_table("schema.table") == ("table",
"schema")
+
+ def test_get_column_names(self):
+ assert Dialect(self.test_db_hook).get_column_names("table", "schema")
== [
+ "id",
+ "name",
+ "firstname",
+ "age",
+ ]
+
+ def test_get_target_fields(self):
+ assert Dialect(self.test_db_hook).get_target_fields("table", "schema")
== [
+ "name",
+ "firstname",
+ "age",
+ ]
+
+ def test_get_primary_keys(self):
+ assert Dialect(self.test_db_hook).get_primary_keys("table", "schema")
== ["id"]
diff --git a/providers/tests/common/sql/hooks/test_dbapi.py b/providers/tests/common/sql/hooks/test_dbapi.py
index 1f3f39aa451..57a74987dfe 100644
--- a/providers/tests/common/sql/hooks/test_dbapi.py
+++ b/providers/tests/common/sql/hooks/test_dbapi.py
@@ -28,6 +28,7 @@ from pyodbc import Cursor
from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
from airflow.hooks.base import BaseHook
from airflow.models import Connection
+from airflow.providers.common.sql.dialects.dialect import Dialect
from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler, fetch_one_handler
@@ -62,6 +63,10 @@ class TestDbApiHook:
def get_conn(self):
return conn
+ @property
+ def dialect(self):
+ return Dialect(self)
+
def get_db_log_messages(self, conn) -> None:
return conn.get_messages()
diff --git a/providers/tests/common/sql/hooks/test_sql.py b/providers/tests/common/sql/hooks/test_sql.py
index 756663ca39c..cb7696b3701 100644
--- a/providers/tests/common/sql/hooks/test_sql.py
+++ b/providers/tests/common/sql/hooks/test_sql.py
@@ -18,6 +18,7 @@
#
from __future__ import annotations
+import inspect
import logging
import logging.config
from unittest.mock import MagicMock
@@ -25,11 +26,14 @@ from unittest.mock import MagicMock
import pytest
 from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import Connection
-from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
+from airflow.providers.common.sql.dialects.dialect import Dialect
+from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler, resolve_dialects
from airflow.utils.session import provide_session
from providers.tests.common.sql.test_utils import mock_hook
+from tests_common.test_utils.providers import get_provider_min_airflow_version
TASK_ID = "sql-operator"
HOST = "host"
@@ -259,6 +263,34 @@ class TestDbApiHook:
assert dbapi_hook.placeholder == "%s"
assert dbapi_hook.connection_invocations == 1
+ @pytest.mark.db_test
+ def test_dialect_name(self):
+ dbapi_hook = mock_hook(DbApiHook)
+ assert dbapi_hook.dialect_name == "default"
+
+ @pytest.mark.db_test
+ def test_dialect(self):
+ dbapi_hook = mock_hook(DbApiHook)
+ assert isinstance(dbapi_hook.dialect, Dialect)
+
+ @pytest.mark.db_test
+    def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_code(self):
+        """
+        Once this test starts failing because the minimum Airflow version for this provider
+        is now 3.0.0 or higher, you should remove the obsolete code in the resolve_dialects
+        function of the common sql provider and remove this test. This test was added as a
+        reminder not to forget to remove the fallback code kept for backward compatibility
+        with Airflow 2.8.x, which isn't needed anymore once this provider depends on
+        Airflow 3.0.0 or higher.
+        """
+        min_airflow_version = get_provider_min_airflow_version("apache-airflow-providers-common-sql")
+
+        # Check if the provider's minimum supported Airflow version is 3.0.0 or higher
+        if min_airflow_version[0] >= 3:
+            method_source = inspect.getsource(resolve_dialects)
+            raise AirflowProviderDeprecationWarning(
+                f"Check TODO's to remove obsolete code in resolve_dialects method:\n\r\n\r\t\t\t{method_source}"
+            )
+
@pytest.mark.db_test
def test_uri(self):
dbapi_hook = mock_hook(DbApiHook)
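
The two dialect tests above pin down the fallback contract: a hook without a registered dialect reports dialect_name "default" and receives the base Dialect. A hedged sketch of that lookup (illustrative only, not the provider's actual code):

    from airflow.providers_manager import ProvidersManager

    def resolve_dialect_info(hook):
        # Dialects are discovered from provider.yaml files; unknown names
        # fall back to the "default" entry backed by the base Dialect class.
        dialects = ProvidersManager().dialects
        return dialects.get(hook.dialect_name, dialects["default"])
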
diff --git a/providers/tests/jdbc/hooks/test_jdbc.py b/providers/tests/jdbc/hooks/test_jdbc.py
index ce4e5266234..646b3e9c09e 100644
--- a/providers/tests/jdbc/hooks/test_jdbc.py
+++ b/providers/tests/jdbc/hooks/test_jdbc.py
@@ -44,17 +44,27 @@ logger = logging.getLogger(__name__)
def get_hook(
hook_params=None,
conn_params=None,
+ conn_type: str | None = None,
login: str | None = "login",
password: str | None = "password",
host: str | None = "host",
schema: str | None = "schema",
port: int | None = 1234,
+ uri: str | None = None,
):
hook_params = hook_params or {}
conn_params = conn_params or {}
connection = Connection(
**{
-            **dict(login=login, password=password, host=host, schema=schema, port=port),
+ **dict(
+ conn_type=conn_type,
+ login=login,
+ password=password,
+ host=host,
+ schema=schema,
+ port=port,
+ uri=uri,
+ ),
**conn_params,
}
)
@@ -251,6 +261,19 @@ class TestJdbcHook:
engine = jdbc_hook.get_sqlalchemy_engine()
assert engine.connect().connection.connection == connection
+ def test_dialect_name(self):
+ jdbc_hook = get_hook(
+ conn_params=dict(extra={"sqlalchemy_scheme": "hana"}),
+ conn_type="jdbc",
+ login=None,
+ password=None,
+ host="localhost",
+ schema="sap",
+ port=30215,
+ )
+
+ assert jdbc_hook.dialect_name == "hana"
+
def test_get_conn_thread_safety(self):
mock_conn = MagicMock()
open_connections = 0
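
test_dialect_name above documents that for JDBC connections the dialect is derived from the sqlalchemy_scheme connection extra rather than from the jdbc conn_type. A short sketch, assuming a connection id "jdbc_sap" (hypothetical) configured like the test:

    from airflow.providers.jdbc.hooks.jdbc import JdbcHook

    # Connection extra: {"sqlalchemy_scheme": "hana"}
    hook = JdbcHook(jdbc_conn_id="jdbc_sap")
    print(hook.dialect_name)  # "hana", taken from the SQLAlchemy scheme, not "jdbc"
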
diff --git a/providers/tests/microsoft/mssql/dialects/__init__.py b/providers/tests/microsoft/mssql/dialects/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/tests/microsoft/mssql/dialects/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/tests/microsoft/mssql/dialects/test_mssql.py b/providers/tests/microsoft/mssql/dialects/test_mssql.py
new file mode 100644
index 00000000000..762c4c463df
--- /dev/null
+++ b/providers/tests/microsoft/mssql/dialects/test_mssql.py
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+from sqlalchemy.engine import Inspector
+
+from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect
+
+
+class TestMsSqlDialect:
+ def setup_method(self):
+        inspector = MagicMock(spec=Inspector)
+ inspector.get_columns.side_effect = lambda table_name, schema: [
+ {"name": "id", "identity": True},
+ {"name": "name"},
+ {"name": "firstname"},
+ {"name": "age"},
+ ]
+        self.test_db_hook = MagicMock(placeholder="?", inspector=inspector, spec=DbApiHook)
+ self.test_db_hook.run.side_effect = lambda *args: [("id",)]
+ self.test_db_hook._escape_column_name_format = '"{}"'
+
+ def test_placeholder(self):
+ assert MsSqlDialect(self.test_db_hook).placeholder == "?"
+
+ def test_get_column_names(self):
+        assert MsSqlDialect(self.test_db_hook).get_column_names("hollywood.actors") == [
+ "id",
+ "name",
+ "firstname",
+ "age",
+ ]
+
+ def test_get_target_fields(self):
+        assert MsSqlDialect(self.test_db_hook).get_target_fields("hollywood.actors") == [
+ "name",
+ "firstname",
+ "age",
+ ]
+
+ def test_get_primary_keys(self):
+        assert MsSqlDialect(self.test_db_hook).get_primary_keys("hollywood.actors") == ["id"]
+
+ def test_generate_replace_sql(self):
+ values = [
+ {"id": "id", "name": "Stallone", "firstname": "Sylvester", "age":
"78"},
+ {"id": "id", "name": "Statham", "firstname": "Jason", "age": "57"},
+ {"id": "id", "name": "Li", "firstname": "Jet", "age": "61"},
+ {"id": "id", "name": "Lundgren", "firstname": "Dolph", "age":
"66"},
+ {"id": "id", "name": "Norris", "firstname": "Chuck", "age": "84"},
+ ]
+ target_fields = ["id", "name", "firstname", "age"]
+        sql = MsSqlDialect(self.test_db_hook).generate_replace_sql("hollywood.actors", values, target_fields)
+ assert (
+ sql
+ == """
+ MERGE INTO hollywood.actors WITH (ROWLOCK) AS target
+            USING (SELECT ? AS id, ? AS name, ? AS firstname, ? AS age) AS source
+ ON target.id = source.id
+ WHEN MATCHED THEN
+                UPDATE SET target.name = source.name, target.firstname = source.firstname, target.age = source.age
+ WHEN NOT MATCHED THEN
+                INSERT (id, name, firstname, age) VALUES (source.id, source.name, source.firstname, source.age);
+ """.strip()
+ )
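
Because MSSQL has no native REPLACE statement, MsSqlDialect rewrites a replace into the MERGE ... WITH (ROWLOCK) statement asserted above. A minimal sketch of reaching it through the hook, assuming a configured "mssql_default" connection (hypothetical):

    from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook

    hook = MsSqlHook(mssql_conn_id="mssql_default")
    # replace=True delegates SQL generation to MsSqlDialect.generate_replace_sql
    hook.insert_rows(
        table="hollywood.actors",
        rows=[(1, "Stallone", "Sylvester", 78)],
        target_fields=["id", "name", "firstname", "age"],
        replace=True,
    )
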
diff --git a/providers/tests/microsoft/mssql/hooks/test_mssql.py b/providers/tests/microsoft/mssql/hooks/test_mssql.py
index be8f921112a..7153edde852 100644
--- a/providers/tests/microsoft/mssql/hooks/test_mssql.py
+++ b/providers/tests/microsoft/mssql/hooks/test_mssql.py
@@ -20,8 +20,11 @@ from __future__ import annotations
from unittest import mock
import pytest
+import sqlalchemy
+from airflow.configuration import conf
from airflow.models import Connection
+from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect
from providers.tests.microsoft.conftest import load_file
@@ -30,9 +33,72 @@ try:
except ImportError:
pytest.skip("MSSQL not available", allow_module_level=True)
+PYMSSQL_CONN = Connection(
+ conn_type="mssql", host="ip", schema="share", login="username",
password="password", port=8081
+)
+PYMSSQL_CONN_ALT = Connection(
+ conn_type="mssql", host="ip", schema="", login="username",
password="password", port=8081
+)
+PYMSSQL_CONN_ALT_1 = Connection(
+ conn_type="mssql",
+ host="ip",
+ schema="",
+ login="username",
+ password="password",
+ port=8081,
+ extra={"SQlalchemy_Scheme": "mssql+testdriver"},
+)
+PYMSSQL_CONN_ALT_2 = Connection(
+ conn_type="mssql",
+ host="ip",
+ schema="",
+ login="username",
+ password="password",
+ port=8081,
+ extra={"SQlalchemy_Scheme": "mssql+testdriver", "myparam": "5@-//*"},
+)
+
+
+def get_target_fields(self, table: str) -> list[str] | None:
+ return [
+ "ReportRefreshDate",
+ "UserId",
+ "UserPrincipalName",
+ "LastActivityDate",
+ "IsDeleted",
+ "DeletedDate",
+ "AssignedProducts",
+ "TeamChatMessageCount",
+ "PrivateChatMessageCount",
+ "CallCount",
+ "MeetingCount",
+ "MeetingsOrganizedCount",
+ "MeetingsAttendedCount",
+ "AdHocMeetingsOrganizedCount",
+ "AdHocMeetingsAttendedCount",
+ "ScheduledOne-timeMeetingsOrganizedCount",
+ "ScheduledOne-timeMeetingsAttendedCount",
+ "ScheduledRecurringMeetingsOrganizedCount",
+ "ScheduledRecurringMeetingsAttendedCount",
+ "AudioDuration",
+ "VideoDuration",
+ "ScreenShareDuration",
+ "AudioDurationInSeconds",
+ "VideoDurationInSeconds",
+ "ScreenShareDurationInSeconds",
+ "HasOtherAction",
+ "UrgentMessages",
+ "PostMessages",
+ "TenantDisplayName",
+ "SharedChannelTenantDisplayNames",
+ "ReplyMessages",
+ "IsLicensed",
+ "ReportPeriod",
+ "LoadDate",
+ ]
[email protected]
-def get_primary_keys():
+
+def get_primary_keys(self, table: str) -> list[str] | None:
return [
"GroupDisplayName",
"OwnerPrincipalName",
@@ -80,6 +146,14 @@ URI_TEST_CASES = [
class TestMsSqlHook:
+ def setup_method(self):
+ MsSqlHook._resolve_target_fields = True
+
+ def teardown_method(self, method):
+ MsSqlHook._resolve_target_fields = conf.getboolean(
+ "core", "dbapihook_resolve_target_fields", fallback=False
+ )
+
@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn")
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_connection")
     def test_get_conn_should_return_connection(self, get_connection, mssql_get_conn, mssql_connections):
@@ -161,88 +235,71 @@ class TestMsSqlHook:
hook.get_sqlalchemy_engine()
@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
-    def test_generate_insert_sql(self, get_connection, mssql_connections, get_primary_keys):
- get_connection.return_value = mssql_connections["default"]
+    @mock.patch(
+        "airflow.providers.microsoft.mssql.dialects.mssql.MsSqlDialect.get_target_fields",
+        get_target_fields,
+    )
+    @mock.patch(
+        "airflow.providers.microsoft.mssql.dialects.mssql.MsSqlDialect.get_primary_keys",
+        get_primary_keys,
+    )
+ def test_generate_insert_sql(self, get_connection):
+ get_connection.return_value = PYMSSQL_CONN
+
+ hook = MsSqlHook()
+ sql = hook._generate_insert_sql(
+ table="YAMMER_GROUPS_ACTIVITY_DETAIL",
+ values=[
+ "2024-07-17",
+ "daa5b44c-80d6-4e22-85b5-a94e04cf7206",
+ "[email protected]",
+ "2024-07-17",
+ 0,
+ 0.0,
+ "MICROSOFT FABRIC (FREE)+MICROSOFT 365 E5",
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ "PT0S",
+ "PT0S",
+ "PT0S",
+ 0,
+ 0,
+ 0,
+ "Yes",
+ 0,
+ 0,
+ "APACHE",
+ 0.0,
+ 0,
+ "Yes",
+ 1,
+ "2024-07-17T00:00:00+00:00",
+ ],
+ replace=True,
+ )
+ assert sql == load_file("resources", "replace.sql")
+
+ def test_dialect_name(self):
+ hook = MsSqlHook()
+ assert hook.dialect_name == "mssql"
+
+ def test_dialect(self):
+ hook = MsSqlHook()
+        assert isinstance(hook.dialect, MsSqlDialect)

+    def test_reserved_words(self):
hook = MsSqlHook()
-        with mock.patch.object(hook, "get_primary_keys", return_value=get_primary_keys):
- sql = hook._generate_insert_sql(
- table="YAMMER_GROUPS_ACTIVITY_DETAIL",
- values=[
- "2024-07-17",
- "daa5b44c-80d6-4e22-85b5-a94e04cf7206",
- "[email protected]",
- "2024-07-17",
- 0,
- 0.0,
- "MICROSOFT FABRIC (FREE)+MICROSOFT 365 E5",
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- "PT0S",
- "PT0S",
- "PT0S",
- 0,
- 0,
- 0,
- "Yes",
- 0,
- 0,
- "APACHE",
- 0.0,
- 0,
- "Yes",
- 1,
- "2024-07-17T00:00:00+00:00",
- ],
- target_fields=[
- "ReportRefreshDate",
- "UserId",
- "UserPrincipalName",
- "LastActivityDate",
- "IsDeleted",
- "DeletedDate",
- "AssignedProducts",
- "TeamChatMessageCount",
- "PrivateChatMessageCount",
- "CallCount",
- "MeetingCount",
- "MeetingsOrganizedCount",
- "MeetingsAttendedCount",
- "AdHocMeetingsOrganizedCount",
- "AdHocMeetingsAttendedCount",
- "ScheduledOne-timeMeetingsOrganizedCount",
- "ScheduledOne-timeMeetingsAttendedCount",
- "ScheduledRecurringMeetingsOrganizedCount",
- "ScheduledRecurringMeetingsAttendedCount",
- "AudioDuration",
- "VideoDuration",
- "ScreenShareDuration",
- "AudioDurationInSeconds",
- "VideoDurationInSeconds",
- "ScreenShareDurationInSeconds",
- "HasOtherAction",
- "UrgentMessages",
- "PostMessages",
- "TenantDisplayName",
- "SharedChannelTenantDisplayNames",
- "ReplyMessages",
- "IsLicensed",
- "ReportPeriod",
- "LoadDate",
- ],
- replace=True,
- )
- assert sql == load_file("resources", "replace.sql")
+        assert hook.reserved_words == sqlalchemy.dialects.mssql.base.RESERVED_WORDS
@pytest.mark.db_test
@mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
diff --git a/providers/tests/microsoft/mssql/resources/replace.sql b/providers/tests/microsoft/mssql/resources/replace.sql
index f8fb93b382e..07c7ec29e01 100644
--- a/providers/tests/microsoft/mssql/resources/replace.sql
+++ b/providers/tests/microsoft/mssql/resources/replace.sql
@@ -17,10 +17,10 @@
under the License.
*/
-MERGE INTO YAMMER_GROUPS_ACTIVITY_DETAIL AS target
-    USING (SELECT %s AS ReportRefreshDate, %s AS UserId, %s AS UserPrincipalName, %s AS LastActivityDate, %s AS IsDeleted, %s AS DeletedDate, %s AS AssignedProducts, %s AS TeamChatMessageCount, %s AS PrivateChatMessageCount, %s AS CallCount, %s AS MeetingCount, %s AS MeetingsOrganizedCount, %s AS MeetingsAttendedCount, %s AS AdHocMeetingsOrganizedCount, %s AS AdHocMeetingsAttendedCount, %s AS ScheduledOne-timeMeetingsOrganizedCount, %s AS ScheduledOne-timeMeetingsAttendedCount, %s AS [...]
-    ON target.GroupDisplayName = source.GroupDisplayName AND target.OwnerPrincipalName = source.OwnerPrincipalName AND target.ReportPeriod = source.ReportPeriod AND target.ReportRefreshDate = source.ReportRefreshDate
-    WHEN MATCHED THEN
-        UPDATE SET target.UserId = source.UserId, target.UserPrincipalName = source.UserPrincipalName, target.LastActivityDate = source.LastActivityDate, target.IsDeleted = source.IsDeleted, target.DeletedDate = source.DeletedDate, target.AssignedProducts = source.AssignedProducts, target.TeamChatMessageCount = source.TeamChatMessageCount, target.PrivateChatMessageCount = source.PrivateChatMessageCount, target.CallCount = source.CallCount, target.MeetingCount = source.MeetingCount, t [...]
-    WHEN NOT MATCHED THEN
-        INSERT (ReportRefreshDate, UserId, UserPrincipalName, LastActivityDate, IsDeleted, DeletedDate, AssignedProducts, TeamChatMessageCount, PrivateChatMessageCount, CallCount, MeetingCount, MeetingsOrganizedCount, MeetingsAttendedCount, AdHocMeetingsOrganizedCount, AdHocMeetingsAttendedCount, ScheduledOne-timeMeetingsOrganizedCount, ScheduledOne-timeMeetingsAttendedCount, ScheduledRecurringMeetingsOrganizedCount, ScheduledRecurringMeetingsAttendedCount, AudioDuration, VideoDurati [...]
+MERGE INTO YAMMER_GROUPS_ACTIVITY_DETAIL WITH (ROWLOCK) AS target
+    USING (SELECT %s AS ReportRefreshDate, %s AS UserId, %s AS UserPrincipalName, %s AS LastActivityDate, %s AS IsDeleted, %s AS DeletedDate, %s AS AssignedProducts, %s AS TeamChatMessageCount, %s AS PrivateChatMessageCount, %s AS CallCount, %s AS MeetingCount, %s AS MeetingsOrganizedCount, %s AS MeetingsAttendedCount, %s AS AdHocMeetingsOrganizedCount, %s AS AdHocMeetingsAttendedCount, %s AS ScheduledOne-timeMeetingsOrganizedCount, %s AS ScheduledOne-timeMeetingsAttendedCount, % [...]
+    ON target.GroupDisplayName = source.GroupDisplayName AND target.OwnerPrincipalName = source.OwnerPrincipalName AND target.ReportPeriod = source.ReportPeriod AND target.ReportRefreshDate = source.ReportRefreshDate
+    WHEN MATCHED THEN
+        UPDATE SET target.UserId = source.UserId, target.UserPrincipalName = source.UserPrincipalName, target.LastActivityDate = source.LastActivityDate, target.IsDeleted = source.IsDeleted, target.DeletedDate = source.DeletedDate, target.AssignedProducts = source.AssignedProducts, target.TeamChatMessageCount = source.TeamChatMessageCount, target.PrivateChatMessageCount = source.PrivateChatMessageCount, target.CallCount = source.CallCount, target.MeetingCount = source.MeetingCoun [...]
+    WHEN NOT MATCHED THEN
+        INSERT (ReportRefreshDate, UserId, UserPrincipalName, LastActivityDate, IsDeleted, DeletedDate, AssignedProducts, TeamChatMessageCount, PrivateChatMessageCount, CallCount, MeetingCount, MeetingsOrganizedCount, MeetingsAttendedCount, AdHocMeetingsOrganizedCount, AdHocMeetingsAttendedCount, ScheduledOne-timeMeetingsOrganizedCount, ScheduledOne-timeMeetingsAttendedCount, ScheduledRecurringMeetingsOrganizedCount, ScheduledRecurringMeetingsAttendedCount, AudioDuration, VideoDu [...]
diff --git a/providers/tests/mysql/hooks/test_mysql.py b/providers/tests/mysql/hooks/test_mysql.py
index 0e21047b107..6facc6f4b44 100644
--- a/providers/tests/mysql/hooks/test_mysql.py
+++ b/providers/tests/mysql/hooks/test_mysql.py
@@ -24,6 +24,7 @@ from contextlib import closing
from unittest import mock
import pytest
+import sqlalchemy
from airflow.models import Connection
from airflow.models.dag import DAG
@@ -31,18 +32,20 @@ from airflow.models.dag import DAG
try:
import MySQLdb.cursors
- from airflow.providers.mysql.hooks.mysql import MySqlHook
+ MYSQL_AVAILABLE = True
except ImportError:
- pytest.skip("MySQL not available", allow_module_level=True)
-
+ MYSQL_AVAILABLE = False
+from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.utils import timezone
from tests_common.test_utils.asserts import assert_equal_ignore_multiple_spaces
SSL_DICT = {"cert": "/tmp/client-cert.pem", "ca": "/tmp/server-ca.pem", "key":
"/tmp/client-key.pem"}
+INSERT_SQL_STATEMENT = "INSERT INTO connection (id, conn_id, conn_type,
description, host, `schema`, login, password, port, is_encrypted,
is_extra_encrypted, extra) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)"
[email protected](not MYSQL_AVAILABLE, reason="MySQL not available")
class TestMySqlHookConn:
def setup_method(self):
self.connection = Connection(
@@ -223,6 +226,7 @@ class MockMySQLConnectorConnection:
self._autocommit = autocommit
[email protected]_test
class TestMySqlHook:
def setup_method(self):
self.cur = mock.MagicMock(rowcount=0)
@@ -327,6 +331,80 @@ class TestMySqlHook:
),
)
+ def test_reserved_words(self):
+ hook = MySqlHook()
+        assert hook.reserved_words == sqlalchemy.dialects.mysql.reserved_words.RESERVED_WORDS_MYSQL
+
+ def test_generate_insert_sql_without_already_escaped_column_name(self):
+ values = [
+ "1",
+ "mssql_conn",
+ "mssql",
+ "MSSQL connection",
+ "localhost",
+ "airflow",
+ "admin",
+ "admin",
+ 1433,
+ False,
+ False,
+ {},
+ ]
+ target_fields = [
+ "id",
+ "conn_id",
+ "conn_type",
+ "description",
+ "host",
+ "schema",
+ "login",
+ "password",
+ "port",
+ "is_encrypted",
+ "is_extra_encrypted",
+ "extra",
+ ]
+ hook = MySqlHook()
+ assert (
+            hook._generate_insert_sql(table="connection", values=values, target_fields=target_fields)
+ == INSERT_SQL_STATEMENT
+ )
+
+ def test_generate_insert_sql_with_already_escaped_column_name(self):
+ values = [
+ "1",
+ "mssql_conn",
+ "mssql",
+ "MSSQL connection",
+ "localhost",
+ "airflow",
+ "admin",
+ "admin",
+ 1433,
+ False,
+ False,
+ {},
+ ]
+ target_fields = [
+ "id",
+ "conn_id",
+ "conn_type",
+ "description",
+ "host",
+ "`schema`",
+ "login",
+ "password",
+ "port",
+ "is_encrypted",
+ "is_extra_encrypted",
+ "extra",
+ ]
+ hook = MySqlHook()
+ assert (
+            hook._generate_insert_sql(table="connection", values=values, target_fields=target_fields)
+ == INSERT_SQL_STATEMENT
+ )
+
DEFAULT_DATE = timezone.datetime(2015, 1, 1)
DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
@@ -348,6 +426,7 @@ class MySqlContext:
@pytest.mark.backend("mysql")
[email protected](not MYSQL_AVAILABLE, reason="MySQL not available")
class TestMySql:
def setup_method(self):
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
diff --git a/providers/tests/odbc/hooks/test_odbc.py b/providers/tests/odbc/hooks/test_odbc.py
index 5d2e195dcc6..038bb4e1c4f 100644
--- a/providers/tests/odbc/hooks/test_odbc.py
+++ b/providers/tests/odbc/hooks/test_odbc.py
@@ -27,6 +27,7 @@ from urllib.parse import quote_plus, urlsplit
import pyodbc
import pytest
+from sqlalchemy.exc import ArgumentError
from airflow.providers.odbc.hooks.odbc import OdbcHook
@@ -78,6 +79,10 @@ def pyodbc_instancecheck():
return PyodbcRow
+def raise_argument_error():
+ raise ArgumentError()
+
+
class TestOdbcHook:
def test_driver_in_extra_not_used(self):
         conn_params = dict(extra=json.dumps(dict(Driver="Fake Driver", Fake_Param="Fake Param")))
@@ -342,6 +347,26 @@ class TestOdbcHook:
result = hook.run("SQL")
assert result is None
+ def test_dialect_name_when_resolved_from_sqlalchemy_uri(self):
+ hook = mock_hook(OdbcHook)
+ assert hook.dialect_name == "mssql"
+
+ def test_dialect_name_when_resolved_from_conn_type(self):
+ hook = mock_hook(OdbcHook)
+ hook.get_conn().conn_type = "sqlite"
+ hook.get_uri = raise_argument_error
+ assert hook.dialect_name == "default"
+
+ def test_dialect_name_when_resolved_from_sqlalchemy_scheme_in_extra(self):
+        hook = mock_hook(OdbcHook, conn_params={"extra": {"sqlalchemy_scheme": "mssql+pymssql"}})
+ hook.get_uri = raise_argument_error
+ assert hook.dialect_name == "mssql"
+
+ def test_dialect_name_when_resolved_from_dialect_in_extra(self):
+        hook = mock_hook(OdbcHook, conn_params={"extra": {"dialect": "oracle"}})
+ hook.get_uri = raise_argument_error
+ assert hook.dialect_name == "oracle"
+
def test_get_sqlalchemy_engine_verify_creator_is_being_used(self):
         hook = mock_hook(OdbcHook, conn_params={"extra": {"sqlalchemy_scheme": "sqlite"}})
diff --git a/providers/tests/postgres/dialects/__init__.py b/providers/tests/postgres/dialects/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/tests/postgres/dialects/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/tests/postgres/dialects/test_postgres.py b/providers/tests/postgres/dialects/test_postgres.py
new file mode 100644
index 00000000000..ab4968a6645
--- /dev/null
+++ b/providers/tests/postgres/dialects/test_postgres.py
@@ -0,0 +1,87 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+from sqlalchemy.engine import Inspector
+
+from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.postgres.dialects.postgres import PostgresDialect
+
+
+class TestPostgresDialect:
+ def setup_method(self):
+        inspector = MagicMock(spec=Inspector)
+ inspector.get_columns.side_effect = lambda table_name, schema: [
+ {"name": "id", "identity": True},
+ {"name": "name"},
+ {"name": "firstname"},
+ {"name": "age"},
+ ]
+
+ def get_records(sql, parameters):
+ assert isinstance(sql, str)
+ assert "hollywood" in parameters, "Missing 'schema' in parameters"
+ assert "actors" in parameters, "Missing 'table' in parameters"
+ return [("id",)]
+
+        self.test_db_hook = MagicMock(placeholder="?", inspector=inspector, spec=DbApiHook)
+ self.test_db_hook.get_records.side_effect = get_records
+        self.test_db_hook._insert_statement_format = "INSERT INTO {} {} VALUES ({})"
+ self.test_db_hook._escape_column_name_format = '"{}"'
+
+ def test_placeholder(self):
+ assert PostgresDialect(self.test_db_hook).placeholder == "?"
+
+ def test_get_column_names(self):
+        assert PostgresDialect(self.test_db_hook).get_column_names("hollywood.actors") == [
+ "id",
+ "name",
+ "firstname",
+ "age",
+ ]
+
+ def test_get_target_fields(self):
+        assert PostgresDialect(self.test_db_hook).get_target_fields("hollywood.actors") == [
+ "name",
+ "firstname",
+ "age",
+ ]
+
+ def test_get_primary_keys(self):
+        assert PostgresDialect(self.test_db_hook).get_primary_keys("hollywood.actors") == ["id"]
+
+ def test_generate_replace_sql(self):
+ values = [
+ {"id": "id", "name": "Stallone", "firstname": "Sylvester", "age":
"78"},
+ {"id": "id", "name": "Statham", "firstname": "Jason", "age": "57"},
+ {"id": "id", "name": "Li", "firstname": "Jet", "age": "61"},
+ {"id": "id", "name": "Lundgren", "firstname": "Dolph", "age":
"66"},
+ {"id": "id", "name": "Norris", "firstname": "Chuck", "age": "84"},
+ ]
+ target_fields = ["id", "name", "firstname", "age"]
+ sql = PostgresDialect(self.test_db_hook).generate_replace_sql(
+ "hollywood.actors", values, target_fields
+ )
+ assert (
+ sql
+ == """
+            INSERT INTO hollywood.actors (id, name, firstname, age) VALUES (?,?,?,?,?) ON CONFLICT (id) DO UPDATE SET name = excluded.name, firstname = excluded.firstname, age = excluded.age
+ """.strip()
+ )
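
Unlike the MSSQL MERGE above, PostgresDialect rewrites a replace into INSERT ... ON CONFLICT (pk) DO UPDATE, as the expected statement shows. A minimal sketch of driving it through the hook, assuming a configured "postgres_default" connection (hypothetical):

    from airflow.providers.postgres.hooks.postgres import PostgresHook

    hook = PostgresHook(postgres_conn_id="postgres_default")
    # replace=True routes through PostgresDialect.generate_replace_sql, upserting
    # on the table's primary key via ON CONFLICT ... DO UPDATE.
    hook.insert_rows(
        table="hollywood.actors",
        rows=[(1, "Stallone", "Sylvester", 78)],
        target_fields=["id", "name", "firstname", "age"],
        replace=True,
    )
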
diff --git a/providers/tests/postgres/hooks/test_postgres.py b/providers/tests/postgres/hooks/test_postgres.py
index 76206d57958..2483dba9132 100644
--- a/providers/tests/postgres/hooks/test_postgres.py
+++ b/providers/tests/postgres/hooks/test_postgres.py
@@ -24,12 +24,16 @@ from unittest import mock
import psycopg2.extras
import pytest
+import sqlalchemy
from airflow.exceptions import AirflowException
from airflow.models import Connection
+from airflow.providers.postgres.dialects.postgres import PostgresDialect
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils.types import NOTSET
+INSERT_SQL_STATEMENT = "INSERT INTO connection (id, conn_id, conn_type,
description, host, {}, login, password, port, is_encrypted, is_extra_encrypted,
extra) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)"
+
class TestPostgresHookConn:
def setup_method(self):
@@ -645,3 +649,81 @@ class TestPostgresHook:
assert "NOTICE: Message from db: 42" in caplog.text
finally:
hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)")
+
+ def test_dialect_name(self):
+ assert self.db_hook.dialect_name == "postgresql"
+
+ def test_dialect(self):
+ assert isinstance(self.db_hook.dialect, PostgresDialect)
+
+ def test_reserved_words(self):
+ hook = PostgresHook()
+        assert hook.reserved_words == sqlalchemy.dialects.postgresql.base.RESERVED_WORDS
+
+ def test_generate_insert_sql_without_already_escaped_column_name(self):
+ values = [
+ "1",
+ "mssql_conn",
+ "mssql",
+ "MSSQL connection",
+ "localhost",
+ "airflow",
+ "admin",
+ "admin",
+ 1433,
+ False,
+ False,
+ {},
+ ]
+ target_fields = [
+ "id",
+ "conn_id",
+ "conn_type",
+ "description",
+ "host",
+ "schema",
+ "login",
+ "password",
+ "port",
+ "is_encrypted",
+ "is_extra_encrypted",
+ "extra",
+ ]
+ hook = PostgresHook()
+ assert hook._generate_insert_sql(
+ table="connection", values=values, target_fields=target_fields
+ ) == INSERT_SQL_STATEMENT.format("schema")
+
+ def test_generate_insert_sql_with_already_escaped_column_name(self):
+ values = [
+ "1",
+ "mssql_conn",
+ "mssql",
+ "MSSQL connection",
+ "localhost",
+ "airflow",
+ "admin",
+ "admin",
+ 1433,
+ False,
+ False,
+ {},
+ ]
+ target_fields = [
+ "id",
+ "conn_id",
+ "conn_type",
+ "description",
+ "host",
+ '"schema"',
+ "login",
+ "password",
+ "port",
+ "is_encrypted",
+ "is_extra_encrypted",
+ "extra",
+ ]
+ hook = PostgresHook()
+ assert hook._generate_insert_sql(
+ table="connection", values=values, target_fields=target_fields
+ ) == INSERT_SQL_STATEMENT.format('"schema"')
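
The PostgreSQL escaping tests mirror the MySQL ones, with double quotes as the escape character: "schema" is emitted quoted whether or not the caller pre-escapes it. An illustrative sketch:

    from airflow.providers.postgres.hooks.postgres import PostgresHook

    hook = PostgresHook()
    sql = hook._generate_insert_sql(
        table="connection",
        values=["1", "conn", "public"],
        target_fields=["id", "conn_id", "schema"],  # emitted double-quoted, per the tests
    )
    # -> INSERT INTO connection (id, conn_id, "schema") VALUES (%s,%s,%s)
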
diff --git a/providers/tests/teradata/hooks/test_teradata.py b/providers/tests/teradata/hooks/test_teradata.py
index f2bcb1ee269..10754555b1a 100644
--- a/providers/tests/teradata/hooks/test_teradata.py
+++ b/providers/tests/teradata/hooks/test_teradata.py
@@ -28,6 +28,7 @@ from airflow.providers.teradata.hooks.teradata import TeradataHook, _handle_user
class TestTeradataHook:
def setup_method(self):
self.connection = Connection(
+ conn_id="teradata_conn_id",
conn_type="teradata",
login="login",
password="password",
@@ -43,12 +44,14 @@ class TestTeradataHook:
conn = self.conn
class UnitTestTeradataHook(TeradataHook):
- conn_name_attr = "teradata_conn_id"
-
def get_conn(self):
return conn
- self.test_db_hook = UnitTestTeradataHook()
+ @classmethod
+ def get_connection(cls, conn_id: str) -> Connection:
+ return conn
+
+        self.test_db_hook = UnitTestTeradataHook(teradata_conn_id="teradata_conn_id")
@mock.patch("teradatasql.connect")
def test_get_conn(self, mock_connect):
diff --git a/tests/always/test_providers_manager.py b/tests/always/test_providers_manager.py
index 4af0c729713..a808aedbb80 100644
--- a/tests/always/test_providers_manager.py
+++ b/tests/always/test_providers_manager.py
@@ -446,7 +446,8 @@ class TestProviderManager:
def test_dialects(self):
provider_manager = ProvidersManager()
dialect_class_names = list(provider_manager.dialects)
- assert len(dialect_class_names) == 0
+ assert len(dialect_class_names) == 3
+ assert dialect_class_names == ["default", "mssql", "postgresql"]
@patch("airflow.providers_manager.import_string")
def test_optional_feature_no_warning(self, mock_importlib_import_string):
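
With the common-sql, mssql, and postgres providers installed, the registry now exposes exactly the three dialects asserted above. A quick way to inspect them at runtime:

    from airflow.providers_manager import ProvidersManager

    # Dialects are discovered from each provider's provider.yaml at manager init.
    print(list(ProvidersManager().dialects))  # ['default', 'mssql', 'postgresql']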