This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f6c7388cfa Create SQLAlchemy engine from connection in DB Hook and
added autocommit param to insert_rows method (#40669)
f6c7388cfa is described below
commit f6c7388cfa70874d84f312a5859a4f510fef0084
Author: David Blain <[email protected]>
AuthorDate: Fri Jul 26 21:33:27 2024 +0200
Create SQLAlchemy engine from connection in DB Hook and added autocommit
param to insert_rows method (#40669)
* refactor: Refactored get_sqlalchemy_engine method of DbApiHook to use the
get_conn result to build the sqlalchemy engine
* refactor: Added autocommit parameter to insert_rows just like with the
run method as this parameter will also be needed once we have the
SQLInsertRowsOperator
* refactor: Updated the docstring of the insert_rows method
* refactor: Updated sql.pyi
* refactor: Try to fix AttributeError: type object 'SkipDBTestsSession' has
no attribute 'get_bind'
* refactor: Implemented the sqlalchemy_url property for JdbcHook
* refactor: Refactored get_sqlalchemy_engine in DbApiHook, if Hook
implements the sqlalchemy_url property then use it, otherwise fallback to
original implementation with get_uri
* refactor: Added SQLAlchemy Inspector property in DbApiHook
* refactor: Reformatted test_sqlalchemy_url_with_sqlalchemy_scheme in
TestJdbcHook
* refactor: Fixed static checks in DbApiHook
* refactor: Fixed some static checks
* docs: Updated docstring of JdbcHook and mentioned importance of
sqlalchemy_scheme parameter
---------
Co-authored-by: David Blain <[email protected]>
---
airflow/providers/common/sql/hooks/sql.py | 23 ++++++++++++++++++++---
airflow/providers/common/sql/hooks/sql.pyi | 6 ++++--
airflow/providers/jdbc/hooks/jdbc.py | 26 +++++++++++++++++++++++++-
airflow/settings.py | 10 ++++++++++
tests/providers/jdbc/hooks/test_jdbc.py | 15 +++++++++++++++
5 files changed, 74 insertions(+), 6 deletions(-)
diff --git a/airflow/providers/common/sql/hooks/sql.py
b/airflow/providers/common/sql/hooks/sql.py
index 4dba2d843a..dc66cc40bd 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -40,6 +40,7 @@ from urllib.parse import urlparse
import sqlparse
from more_itertools import chunked
from sqlalchemy import create_engine
+from sqlalchemy.engine import Inspector
from airflow.exceptions import (
AirflowException,
@@ -242,7 +243,20 @@ class DbApiHook(BaseHook):
"""
if engine_kwargs is None:
engine_kwargs = {}
- return create_engine(self.get_uri(), **engine_kwargs)
+ engine_kwargs["creator"] = self.get_conn
+
+ try:
+ url = self.sqlalchemy_url
+ except NotImplementedError:
+ url = self.get_uri()
+
+ self.log.debug("url: %s", url)
+ self.log.debug("engine_kwargs: %s", engine_kwargs)
+ return create_engine(url=url, **engine_kwargs)
+
+ @property
+ def inspector(self) -> Inspector:
+ return Inspector.from_engine(self.get_sqlalchemy_engine())
def get_pandas_df(
self,
@@ -571,6 +585,7 @@ class DbApiHook(BaseHook):
replace=False,
*,
executemany=False,
+ autocommit=False,
**kwargs,
):
"""
@@ -585,12 +600,14 @@ class DbApiHook(BaseHook):
:param commit_every: The maximum number of rows to insert in one
transaction. Set to 0 to insert all rows in one transaction.
:param replace: Whether to replace instead of insert
- :param executemany: (Deprecated) If True, all rows are inserted at
once in
+ :param executemany: If True, all rows are inserted at once in
chunks defined by the commit_every parameter. This only works if
all rows
have same number of column names, but leads to better performance.
+ :param autocommit: What to set the connection's autocommit setting to
+ before executing the query.
"""
nb_rows = 0
- with self._create_autocommit_connection() as conn:
+ with self._create_autocommit_connection(autocommit) as conn:
conn.commit()
with closing(conn.cursor()) as cur:
if self.supports_executemany or executemany:
diff --git a/airflow/providers/common/sql/hooks/sql.pyi
b/airflow/providers/common/sql/hooks/sql.pyi
index 27142aeaf2..625ec1d320 100644
--- a/airflow/providers/common/sql/hooks/sql.pyi
+++ b/airflow/providers/common/sql/hooks/sql.pyi
@@ -42,7 +42,7 @@ from airflow.providers.openlineage.extractors import
OperatorLineage as Operator
from airflow.providers.openlineage.sqlparser import DatabaseInfo as
DatabaseInfo
from functools import cached_property as cached_property
from pandas import DataFrame as DataFrame
-from sqlalchemy.engine import URL as URL
+from sqlalchemy.engine import Inspector, URL as URL
from typing import Any, Callable, Generator, Iterable, Mapping, Protocol,
Sequence, TypeVar, overload
T = TypeVar("T")
@@ -64,7 +64,6 @@ class DbApiHook(BaseHook):
log_sql: Incomplete
descriptions: Incomplete
def __init__(self, *args, schema: str | None = None, log_sql: bool = True,
**kwargs) -> None: ...
-
def get_conn_id(self) -> str: ...
@cached_property
def placeholder(self): ...
@@ -73,6 +72,8 @@ class DbApiHook(BaseHook):
@property
def sqlalchemy_url(self) -> URL: ...
def get_sqlalchemy_engine(self, engine_kwargs: Incomplete | None = None):
...
+ @property
+ def inspector(self) -> Inspector: ...
def get_pandas_df(
self, sql, parameters: list | tuple | Mapping[str, Any] | None = None,
**kwargs
) -> DataFrame: ...
@@ -123,6 +124,7 @@ class DbApiHook(BaseHook):
replace: bool = False,
*,
executemany: bool = False,
+ autocommit: bool = False,
**kwargs,
): ...
def bulk_dump(self, table, tmp_file) -> None: ...
diff --git a/airflow/providers/jdbc/hooks/jdbc.py
b/airflow/providers/jdbc/hooks/jdbc.py
index cf5d2dd47d..81c63cbe3d 100644
--- a/airflow/providers/jdbc/hooks/jdbc.py
+++ b/airflow/providers/jdbc/hooks/jdbc.py
@@ -23,7 +23,9 @@ from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
import jaydebeapi
+from sqlalchemy.engine import URL
+from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook
if TYPE_CHECKING:
@@ -60,7 +62,12 @@ class JdbcHook(DbApiHook):
"providers.jdbc" section of the Airflow configuration. If you're
enabling these options in Airflow
configuration, you should make sure that you trust the users who
can edit connections in the UI
to not use it maliciously.
- 4. Patch the ``JdbcHook.default_driver_path`` and/or
``JdbcHook.default_driver_class`` values in the
+ 4. Define the "sqlalchemy_scheme" property in the extra of the
connection if you want to use the
+ SQLAlchemy engine from the JdbcHook. When using the JdbcHook, the
"sqlalchemy_scheme" will by
+ default have the "jdbc" value, which is a protocol, not a database
scheme or dialect. So in order
+ to be able to use SQLAlchemy with the JdbcHook, you need to define
the "sqlalchemy_scheme"
+ property in the extra of the connection.
+ 5. Patch the ``JdbcHook.default_driver_path`` and/or
``JdbcHook.default_driver_class`` values in the
``local_settings.py`` file.
See :doc:`/connections/jdbc` for full documentation.
@@ -149,6 +156,23 @@ class JdbcHook(DbApiHook):
self._driver_class = self.default_driver_class
return self._driver_class
+ @property
+ def sqlalchemy_url(self) -> URL:
+ conn = self.get_connection(getattr(self, self.conn_name_attr))
+ sqlalchemy_scheme = conn.extra_dejson.get("sqlalchemy_scheme")
+ if sqlalchemy_scheme is None:
+ raise AirflowException(
+ "The parameter 'sqlalchemy_scheme' must be defined in extra
for JDBC connections!"
+ )
+ return URL.create(
+ drivername=sqlalchemy_scheme,
+ username=conn.login,
+ password=conn.password,
+ host=conn.host,
+ port=conn.port,
+ database=conn.schema,
+ )
+
def get_conn(self) -> jaydebeapi.Connection:
conn: Connection = self.get_connection(self.get_conn_id())
host: str = conn.host
diff --git a/airflow/settings.py b/airflow/settings.py
index 6dc9880271..eb4053f50e 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -256,6 +256,16 @@ class SkipDBTestsSession:
def remove(*args, **kwargs):
pass
+ def get_bind(
+ self,
+ mapper=None,
+ clause=None,
+ bind=None,
+ _sa_skip_events=None,
+ _sa_skip_for_implicit_returning=False,
+ ):
+ pass
+
class TracebackSession:
"""
diff --git a/tests/providers/jdbc/hooks/test_jdbc.py
b/tests/providers/jdbc/hooks/test_jdbc.py
index 6e4387ee1a..cb38ce40ae 100644
--- a/tests/providers/jdbc/hooks/test_jdbc.py
+++ b/tests/providers/jdbc/hooks/test_jdbc.py
@@ -25,6 +25,7 @@ from unittest.mock import Mock, patch
import jaydebeapi
import pytest
+from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.jdbc.hooks.jdbc import JdbcHook, suppress_and_warn
from airflow.utils import db
@@ -186,3 +187,17 @@ class TestJdbcHook:
with pytest.raises(RuntimeError, match="Spam Egg"):
with suppress_and_warn(KeyError):
raise RuntimeError("Spam Egg")
+
+ def test_sqlalchemy_url_without_sqlalchemy_scheme(self):
+ hook_params = {"driver_path": "ParamDriverPath", "driver_class":
"ParamDriverClass"}
+ hook = get_hook(hook_params=hook_params)
+
+ with pytest.raises(AirflowException):
+ hook.sqlalchemy_url
+
+ def test_sqlalchemy_url_with_sqlalchemy_scheme(self):
+ conn_params = dict(extra=json.dumps(dict(sqlalchemy_scheme="mssql")))
+ hook_params = {"driver_path": "ParamDriverPath", "driver_class":
"ParamDriverClass"}
+ hook = get_hook(conn_params=conn_params, hook_params=hook_params)
+
+ assert str(hook.sqlalchemy_url) ==
"mssql://login:password@host:1234/schema"