This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7ab24c7723 Always use the executemany method when inserting rows in DbApiHook, as it's much faster (#38715)
7ab24c7723 is described below
commit 7ab24c7723c65c90626b10db63444b88c0380e14
Author: David Blain <[email protected]>
AuthorDate: Fri Apr 12 11:43:23 2024 +0200
Always use the executemany method when inserting rows in DbApiHook, as it's much faster (#38715)
---------
Co-authored-by: David Blain <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
---
airflow/providers/common/sql/hooks/sql.py | 52 ++++++++++++++-----------
airflow/providers/common/sql/hooks/sql.pyi | 1 +
airflow/providers/odbc/hooks/odbc.py | 1 +
airflow/providers/postgres/hooks/postgres.py | 1 +
airflow/providers/teradata/hooks/teradata.py | 48 ++++++++---------------
tests/deprecations_ignore.yml | 6 +++
tests/providers/common/sql/hooks/test_dbapi.py | 3 +-
tests/providers/postgres/hooks/test_postgres.py | 9 ++---
tests/providers/teradata/hooks/test_teradata.py | 22 ++++++-----
9 files changed, 73 insertions(+), 70 deletions(-)
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index 7f1536a39b..4625c2e014 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import contextlib
import warnings
-from contextlib import closing
+from contextlib import closing, contextmanager
from datetime import datetime
from typing import (
TYPE_CHECKING,
@@ -147,6 +147,8 @@ class DbApiHook(BaseHook):
default_conn_name = "default_conn_id"
# Override if this db supports autocommit.
supports_autocommit = False
+ # Override if this db supports executemany.
+ supports_executemany = False
# Override with the object that exposes the connect method
connector: ConnectorProtocol | None = None
# Override with db-specific query to check connection
@@ -408,10 +410,7 @@ class DbApiHook(BaseHook):
else:
raise ValueError("List of SQL statements is empty")
_last_result = None
- with closing(self.get_conn()) as conn:
- if self.supports_autocommit:
- self.set_autocommit(conn, autocommit)
-
+ with self._create_autocommit_connection(autocommit) as conn:
with closing(conn.cursor()) as cur:
results = []
for sql_statement in sql_list:
@@ -528,6 +527,14 @@ class DbApiHook(BaseHook):
return self._replace_statement_format.format(table, target_fields,
",".join(placeholders))
+ @contextmanager
+ def _create_autocommit_connection(self, autocommit: bool = False):
+ """Context manager that closes the connection after use and detects if
autocommit is supported."""
+ with closing(self.get_conn()) as conn:
+ if self.supports_autocommit:
+ self.set_autocommit(conn, autocommit)
+ yield conn
+
def insert_rows(
self,
table,
@@ -550,47 +557,48 @@ class DbApiHook(BaseHook):
:param commit_every: The maximum number of rows to insert in one
transaction. Set to 0 to insert all rows in one transaction.
:param replace: Whether to replace instead of insert
- :param executemany: Insert all rows at once in chunks defined by the
commit_every parameter, only
- works if all rows have same number of column names but leads to
better performance
+ :param executemany: (Deprecated) If True, all rows are inserted at
once in
+ chunks defined by the commit_every parameter. This only works if
all rows
+ have same number of column names, but leads to better performance.
"""
- i = 0
- with closing(self.get_conn()) as conn:
- if self.supports_autocommit:
- self.set_autocommit(conn, False)
+ if executemany:
+ warnings.warn(
+ "executemany parameter is deprecated, override
supports_executemany instead.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ with self._create_autocommit_connection() as conn:
conn.commit()
-
with closing(conn.cursor()) as cur:
- if executemany:
+ if self.supports_executemany or executemany:
for chunked_rows in chunked(rows, commit_every):
values = list(
map(
- lambda row: tuple(map(lambda cell:
self._serialize_cell(cell, conn), row)),
+ lambda row: self._serialize_cells(row, conn),
chunked_rows,
)
)
sql = self._generate_insert_sql(table, values[0],
target_fields, replace, **kwargs)
self.log.debug("Generated sql: %s", sql)
- cur.fast_executemany = True
cur.executemany(sql, values)
conn.commit()
self.log.info("Loaded %s rows into %s so far",
len(chunked_rows), table)
else:
for i, row in enumerate(rows, 1):
- lst = []
- for cell in row:
- lst.append(self._serialize_cell(cell, conn))
- values = tuple(lst)
+ values = self._serialize_cells(row, conn)
sql = self._generate_insert_sql(table, values,
target_fields, replace, **kwargs)
self.log.debug("Generated sql: %s", sql)
cur.execute(sql, values)
if commit_every and i % commit_every == 0:
conn.commit()
self.log.info("Loaded %s rows into %s so far", i,
table)
+ conn.commit()
+ self.log.info("Done loading. Loaded a total of %s rows into %s",
len(rows), table)
- if not executemany:
- conn.commit()
- self.log.info("Done loading. Loaded a total of %s rows into %s", i,
table)
+ @classmethod
+ def _serialize_cells(cls, row, conn=None):
+ return tuple(cls._serialize_cell(cell, conn) for cell in row)
@staticmethod
def _serialize_cell(cell, conn=None) -> str | None:
diff --git a/airflow/providers/common/sql/hooks/sql.pyi b/airflow/providers/common/sql/hooks/sql.pyi
index 83135a235b..85edd625f9 100644
--- a/airflow/providers/common/sql/hooks/sql.pyi
+++ b/airflow/providers/common/sql/hooks/sql.pyi
@@ -57,6 +57,7 @@ class DbApiHook(BaseHook):
conn_name_attr: str
default_conn_name: str
supports_autocommit: bool
+ supports_executemany: bool
connector: ConnectorProtocol | None
log_sql: Incomplete
descriptions: Incomplete
diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py
index 8cf95bf095..53c4cf207a 100644
--- a/airflow/providers/odbc/hooks/odbc.py
+++ b/airflow/providers/odbc/hooks/odbc.py
@@ -56,6 +56,7 @@ class OdbcHook(DbApiHook):
conn_type = "odbc"
hook_name = "ODBC"
supports_autocommit = True
+ supports_executemany = True
default_driver: str | None = None
diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py
index 0afb7740fe..9e1b3a83d7 100644
--- a/airflow/providers/postgres/hooks/postgres.py
+++ b/airflow/providers/postgres/hooks/postgres.py
@@ -74,6 +74,7 @@ class PostgresHook(DbApiHook):
conn_type = "postgres"
hook_name = "Postgres"
supports_autocommit = True
+ supports_executemany = True
def __init__(self, *args, options: str | None = None, **kwargs) -> None:
if "schema" in kwargs:
diff --git a/airflow/providers/teradata/hooks/teradata.py b/airflow/providers/teradata/hooks/teradata.py
index 73c4fb8ff0..3afc32bc74 100644
--- a/airflow/providers/teradata/hooks/teradata.py
+++ b/airflow/providers/teradata/hooks/teradata.py
@@ -19,12 +19,14 @@
from __future__ import annotations
+import warnings
from typing import TYPE_CHECKING, Any
import sqlalchemy
import teradatasql
from teradatasql import TeradataConnection
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.common.sql.hooks.sql import DbApiHook
if TYPE_CHECKING:
@@ -59,6 +61,9 @@ class TeradataHook(DbApiHook):
# Override if this db supports autocommit.
supports_autocommit = True
+ # Override if this db supports executemany.
+ supports_executemany = True
+
# Override this for hook to have a custom name in the UI selection
conn_type = "teradata"
@@ -97,7 +102,9 @@ class TeradataHook(DbApiHook):
target_fields: list[str] | None = None,
commit_every: int = 5000,
):
- """Insert bulk of records into Teradata SQL Database.
+ """Use :func:`insert_rows` instead, this is deprecated.
+
+ Insert bulk of records into Teradata SQL Database.
This uses prepared statements via `executemany()`. For best
performance,
pass in `rows` as an iterator.
@@ -106,41 +113,20 @@ class TeradataHook(DbApiHook):
specific database
:param rows: the rows to insert into the table
:param target_fields: the names of the columns to fill in the table,
default None.
- If None, each rows should have some order as table columns name
+ If None, each row should have some order as table columns name
:param commit_every: the maximum number of rows to insert in one
transaction
Default 5000. Set greater than 0. Set 1 to insert each row in each
transaction
"""
+ warnings.warn(
+ "bulk_insert_rows is deprecated. Please use the insert_rows method
instead.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+
if not rows:
raise ValueError("parameter rows could not be None or empty
iterable")
- conn = self.get_conn()
- if self.supports_autocommit:
- self.set_autocommit(conn, False)
- cursor = conn.cursor()
- cursor.fast_executemany = True
- values_base = target_fields if target_fields else rows[0]
- prepared_stm = "INSERT INTO {tablename} {columns} VALUES
({values})".format(
- tablename=table,
- columns="({})".format(", ".join(target_fields)) if target_fields
else "",
- values=", ".join("?" for i in range(1, len(values_base) + 1)),
- )
- row_count = 0
- # Chunk the rows
- row_chunk = []
- for row in rows:
- row_chunk.append(row)
- row_count += 1
- if row_count % commit_every == 0:
- cursor.executemany(prepared_stm, row_chunk)
- conn.commit() # type: ignore[attr-defined]
- # Empty chunk
- row_chunk = []
- # Commit the leftover chunk
- if len(row_chunk) > 0:
- cursor.executemany(prepared_stm, row_chunk)
- conn.commit() # type: ignore[attr-defined]
- self.log.info("[%s] inserted %s rows", table, row_count)
- cursor.close()
- conn.close() # type: ignore[attr-defined]
+
+ self.insert_rows(table=table, rows=rows, target_fields=target_fields,
commit_every=commit_every)
def _get_conn_config_teradatasql(self) -> dict[str, Any]:
"""Return set of config params required for connecting to Teradata DB
using teradatasql client."""
diff --git a/tests/deprecations_ignore.yml b/tests/deprecations_ignore.yml
index c88d2b9368..c8f02a193d 100644
--- a/tests/deprecations_ignore.yml
+++ b/tests/deprecations_ignore.yml
@@ -659,6 +659,8 @@
-
tests/providers/cncf/kubernetes/test_pod_generator.py::TestPodGenerator::test_pod_name_confirm_to_max_length
-
tests/providers/cncf/kubernetes/test_pod_generator.py::TestPodGenerator::test_pod_name_is_valid
-
tests/providers/cncf/kubernetes/test_template_rendering.py::test_render_k8s_pod_yaml
+-
tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_insert_rows_executemany
+-
tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_insert_rows_replace_executemany_hana_dialect
-
tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_instance_check_works_for_legacy_db_api_hook
-
tests/providers/common/sql/operators/test_sql.py::TestSQLCheckOperatorDbHook::test_get_hook
-
tests/providers/common/sql/operators/test_sql.py::TestSqlBranch::test_branch_false_with_dag_run
@@ -1059,6 +1061,10 @@
-
tests/providers/ssh/hooks/test_ssh.py::TestSSHHook::test_tunnel_without_password
-
tests/providers/tableau/hooks/test_tableau.py::TestTableauHook::test_get_conn_auth_via_token_and_site_in_init
-
tests/providers/tableau/hooks/test_tableau.py::TestTableauHook::test_get_conn_ssl_default
+-
tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_with_fields
+-
tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_with_commit_every
+-
tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_without_fields
+-
tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_no_rows
-
tests/providers/trino/operators/test_trino.py::test_execute_openlineage_events
-
tests/providers/vertica/operators/test_vertica.py::TestVerticaOperator::test_execute
-
tests/providers/weaviate/operators/test_weaviate.py::TestWeaviateIngestOperator::test_constructor
diff --git a/tests/providers/common/sql/hooks/test_dbapi.py b/tests/providers/common/sql/hooks/test_dbapi.py
index fd9886345f..2c34ee133e 100644
--- a/tests/providers/common/sql/hooks/test_dbapi.py
+++ b/tests/providers/common/sql/hooks/test_dbapi.py
@@ -21,6 +21,7 @@ import json
from unittest import mock
import pytest
+from pyodbc import Cursor
from airflow.hooks.base import BaseHook
from airflow.models import Connection
@@ -39,7 +40,7 @@ class TestDbApiHook:
def setup_method(self, **kwargs):
self.cur = mock.MagicMock(
rowcount=0,
- spec=["description", "rowcount", "execute", "executemany",
"fetchall", "fetchone", "close"],
+ spec=Cursor,
)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py
index 2d62cb4f43..8330ad3b1d 100644
--- a/tests/providers/postgres/hooks/test_postgres.py
+++ b/tests/providers/postgres/hooks/test_postgres.py
@@ -403,8 +403,7 @@ class TestPostgresHook:
assert commit_count == self.conn.commit.call_count
sql = f"INSERT INTO {table} VALUES (%s)"
- for row in rows:
- self.cur.execute.assert_any_call(sql, row)
+ self.cur.executemany.assert_any_call(sql, rows)
def test_insert_rows_replace(self):
table = "table"
@@ -432,8 +431,7 @@ class TestPostgresHook:
f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) "
f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} =
excluded.{fields[1]}"
)
- for row in rows:
- self.cur.execute.assert_any_call(sql, row)
+ self.cur.executemany.assert_any_call(sql, rows)
def test_insert_rows_replace_missing_target_field_arg(self):
table = "table"
@@ -497,8 +495,7 @@ class TestPostgresHook:
f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) "
f"ON CONFLICT ({', '.join(fields)}) DO NOTHING"
)
- for row in rows:
- self.cur.execute.assert_any_call(sql, row)
+ self.cur.executemany.assert_any_call(sql, rows)
def test_rowcount(self):
hook = PostgresHook()
diff --git a/tests/providers/teradata/hooks/test_teradata.py b/tests/providers/teradata/hooks/test_teradata.py
index d47e987f41..af77d66a63 100644
--- a/tests/providers/teradata/hooks/test_teradata.py
+++ b/tests/providers/teradata/hooks/test_teradata.py
@@ -226,9 +226,9 @@ class TestTeradataHook:
"str",
]
self.test_db_hook.insert_rows("table", rows, target_fields)
- self.cur.execute.assert_called_once_with(
+ self.cur.executemany.assert_called_once_with(
"INSERT INTO table (basestring, none, datetime, int, float, str)
VALUES (?,?,?,?,?,?)",
- ("'test_string", None, "2023-08-15T00:00:00", "1", "3.14", "str"),
+ [("'test_string", None, "2023-08-15T00:00:00", "1", "3.14",
"str")],
)
def test_bulk_insert_rows_with_fields(self):
@@ -236,7 +236,8 @@ class TestTeradataHook:
target_fields = ["col1", "col2", "col3"]
self.test_db_hook.bulk_insert_rows("table", rows, target_fields)
self.cur.executemany.assert_called_once_with(
- "INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)", rows
+ "INSERT INTO table (col1, col2, col3) VALUES (?,?,?)",
+ [("1", "2", "3"), ("4", "5", "6"), ("7", "8", "9")],
)
def test_bulk_insert_rows_with_commit_every(self):
@@ -244,19 +245,20 @@ class TestTeradataHook:
target_fields = ["col1", "col2", "col3"]
self.test_db_hook.bulk_insert_rows("table", rows, target_fields,
commit_every=2)
calls = [
- mock.call("INSERT INTO table (col1, col2, col3) values (1, 2, 3)"),
- mock.call("INSERT INTO table (col1, col2, col3) values (1, 2, 3)"),
- ]
- calls = [
- mock.call("INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)",
rows[:2]),
- mock.call("INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)",
rows[2:]),
+ mock.call(
+ "INSERT INTO table (col1, col2, col3) VALUES (?,?,?)", [("1",
"2", "3"), ("4", "5", "6")]
+ ),
+ mock.call("INSERT INTO table (col1, col2, col3) VALUES (?,?,?)",
[("7", "8", "9")]),
]
self.cur.executemany.assert_has_calls(calls, any_order=True)
def test_bulk_insert_rows_without_fields(self):
rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
self.test_db_hook.bulk_insert_rows("table", rows)
- self.cur.executemany.assert_called_once_with("INSERT INTO table
VALUES (?, ?, ?)", rows)
+ self.cur.executemany.assert_called_once_with(
+ "INSERT INTO table VALUES (?,?,?)",
+ [("1", "2", "3"), ("4", "5", "6"), ("7", "8", "9")],
+ )
def test_bulk_insert_rows_no_rows(self):
rows = []