This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7ab24c7723 Always use the executemany method when inserting rows in 
DbApiHook as it's much faster (#38715)
7ab24c7723 is described below

commit 7ab24c7723c65c90626b10db63444b88c0380e14
Author: David Blain <i...@dabla.be>
AuthorDate: Fri Apr 12 11:43:23 2024 +0200

    Always use the executemany method when inserting rows in DbApiHook as it's 
much faster (#38715)
    
    
    ---------
    
    Co-authored-by: David Blain <david.bl...@infrabel.be>
    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
---
 airflow/providers/common/sql/hooks/sql.py       | 52 ++++++++++++++-----------
 airflow/providers/common/sql/hooks/sql.pyi      |  1 +
 airflow/providers/odbc/hooks/odbc.py            |  1 +
 airflow/providers/postgres/hooks/postgres.py    |  1 +
 airflow/providers/teradata/hooks/teradata.py    | 48 ++++++++---------------
 tests/deprecations_ignore.yml                   |  6 +++
 tests/providers/common/sql/hooks/test_dbapi.py  |  3 +-
 tests/providers/postgres/hooks/test_postgres.py |  9 ++---
 tests/providers/teradata/hooks/test_teradata.py | 22 ++++++-----
 9 files changed, 73 insertions(+), 70 deletions(-)

diff --git a/airflow/providers/common/sql/hooks/sql.py 
b/airflow/providers/common/sql/hooks/sql.py
index 7f1536a39b..4625c2e014 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 import contextlib
 import warnings
-from contextlib import closing
+from contextlib import closing, contextmanager
 from datetime import datetime
 from typing import (
     TYPE_CHECKING,
@@ -147,6 +147,8 @@ class DbApiHook(BaseHook):
     default_conn_name = "default_conn_id"
     # Override if this db supports autocommit.
     supports_autocommit = False
+    # Override if this db supports executemany.
+    supports_executemany = False
     # Override with the object that exposes the connect method
     connector: ConnectorProtocol | None = None
     # Override with db-specific query to check connection
@@ -408,10 +410,7 @@ class DbApiHook(BaseHook):
         else:
             raise ValueError("List of SQL statements is empty")
         _last_result = None
-        with closing(self.get_conn()) as conn:
-            if self.supports_autocommit:
-                self.set_autocommit(conn, autocommit)
-
+        with self._create_autocommit_connection(autocommit) as conn:
             with closing(conn.cursor()) as cur:
                 results = []
                 for sql_statement in sql_list:
@@ -528,6 +527,14 @@ class DbApiHook(BaseHook):
 
         return self._replace_statement_format.format(table, target_fields, 
",".join(placeholders))
 
+    @contextmanager
+    def _create_autocommit_connection(self, autocommit: bool = False):
+        """Context manager that closes the connection after use and detects if 
autocommit is supported."""
+        with closing(self.get_conn()) as conn:
+            if self.supports_autocommit:
+                self.set_autocommit(conn, autocommit)
+            yield conn
+
     def insert_rows(
         self,
         table,
@@ -550,47 +557,48 @@ class DbApiHook(BaseHook):
         :param commit_every: The maximum number of rows to insert in one
             transaction. Set to 0 to insert all rows in one transaction.
         :param replace: Whether to replace instead of insert
-        :param executemany: Insert all rows at once in chunks defined by the 
commit_every parameter, only
-            works if all rows have same number of column names but leads to 
better performance
+        :param executemany: (Deprecated) If True, all rows are inserted at 
once in
+            chunks defined by the commit_every parameter. This only works if 
all rows
+            have same number of column names, but leads to better performance.
         """
-        i = 0
-        with closing(self.get_conn()) as conn:
-            if self.supports_autocommit:
-                self.set_autocommit(conn, False)
+        if executemany:
+            warnings.warn(
+                "executemany parameter is deprecated, override 
supports_executemany instead.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
 
+        with self._create_autocommit_connection() as conn:
             conn.commit()
-
             with closing(conn.cursor()) as cur:
-                if executemany:
+                if self.supports_executemany or executemany:
                     for chunked_rows in chunked(rows, commit_every):
                         values = list(
                             map(
-                                lambda row: tuple(map(lambda cell: 
self._serialize_cell(cell, conn), row)),
+                                lambda row: self._serialize_cells(row, conn),
                                 chunked_rows,
                             )
                         )
                         sql = self._generate_insert_sql(table, values[0], 
target_fields, replace, **kwargs)
                         self.log.debug("Generated sql: %s", sql)
-                        cur.fast_executemany = True
                         cur.executemany(sql, values)
                         conn.commit()
                         self.log.info("Loaded %s rows into %s so far", 
len(chunked_rows), table)
                 else:
                     for i, row in enumerate(rows, 1):
-                        lst = []
-                        for cell in row:
-                            lst.append(self._serialize_cell(cell, conn))
-                        values = tuple(lst)
+                        values = self._serialize_cells(row, conn)
                         sql = self._generate_insert_sql(table, values, 
target_fields, replace, **kwargs)
                         self.log.debug("Generated sql: %s", sql)
                         cur.execute(sql, values)
                         if commit_every and i % commit_every == 0:
                             conn.commit()
                             self.log.info("Loaded %s rows into %s so far", i, 
table)
+                    conn.commit()
+        self.log.info("Done loading. Loaded a total of %s rows into %s", 
len(rows), table)
 
-            if not executemany:
-                conn.commit()
-        self.log.info("Done loading. Loaded a total of %s rows into %s", i, 
table)
+    @classmethod
+    def _serialize_cells(cls, row, conn=None):
+        return tuple(cls._serialize_cell(cell, conn) for cell in row)
 
     @staticmethod
     def _serialize_cell(cell, conn=None) -> str | None:
diff --git a/airflow/providers/common/sql/hooks/sql.pyi 
b/airflow/providers/common/sql/hooks/sql.pyi
index 83135a235b..85edd625f9 100644
--- a/airflow/providers/common/sql/hooks/sql.pyi
+++ b/airflow/providers/common/sql/hooks/sql.pyi
@@ -57,6 +57,7 @@ class DbApiHook(BaseHook):
     conn_name_attr: str
     default_conn_name: str
     supports_autocommit: bool
+    supports_executemany: bool
     connector: ConnectorProtocol | None
     log_sql: Incomplete
     descriptions: Incomplete
diff --git a/airflow/providers/odbc/hooks/odbc.py 
b/airflow/providers/odbc/hooks/odbc.py
index 8cf95bf095..53c4cf207a 100644
--- a/airflow/providers/odbc/hooks/odbc.py
+++ b/airflow/providers/odbc/hooks/odbc.py
@@ -56,6 +56,7 @@ class OdbcHook(DbApiHook):
     conn_type = "odbc"
     hook_name = "ODBC"
     supports_autocommit = True
+    supports_executemany = True
 
     default_driver: str | None = None
 
diff --git a/airflow/providers/postgres/hooks/postgres.py 
b/airflow/providers/postgres/hooks/postgres.py
index 0afb7740fe..9e1b3a83d7 100644
--- a/airflow/providers/postgres/hooks/postgres.py
+++ b/airflow/providers/postgres/hooks/postgres.py
@@ -74,6 +74,7 @@ class PostgresHook(DbApiHook):
     conn_type = "postgres"
     hook_name = "Postgres"
     supports_autocommit = True
+    supports_executemany = True
 
     def __init__(self, *args, options: str | None = None, **kwargs) -> None:
         if "schema" in kwargs:
diff --git a/airflow/providers/teradata/hooks/teradata.py 
b/airflow/providers/teradata/hooks/teradata.py
index 73c4fb8ff0..3afc32bc74 100644
--- a/airflow/providers/teradata/hooks/teradata.py
+++ b/airflow/providers/teradata/hooks/teradata.py
@@ -19,12 +19,14 @@
 
 from __future__ import annotations
 
+import warnings
 from typing import TYPE_CHECKING, Any
 
 import sqlalchemy
 import teradatasql
 from teradatasql import TeradataConnection
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 
 if TYPE_CHECKING:
@@ -59,6 +61,9 @@ class TeradataHook(DbApiHook):
     # Override if this db supports autocommit.
     supports_autocommit = True
 
+    # Override if this db supports executemany.
+    supports_executemany = True
+
     # Override this for hook to have a custom name in the UI selection
     conn_type = "teradata"
 
@@ -97,7 +102,9 @@ class TeradataHook(DbApiHook):
         target_fields: list[str] | None = None,
         commit_every: int = 5000,
     ):
-        """Insert bulk of records into Teradata SQL Database.
+        """Use :func:`insert_rows` instead, this is deprecated.
+
+        Insert bulk of records into Teradata SQL Database.
 
         This uses prepared statements via `executemany()`. For best 
performance,
         pass in `rows` as an iterator.
@@ -106,41 +113,20 @@ class TeradataHook(DbApiHook):
             specific database
         :param rows: the rows to insert into the table
         :param target_fields: the names of the columns to fill in the table, 
default None.
-            If None, each rows should have some order as table columns name
+            If None, each row should have some order as table columns name
         :param commit_every: the maximum number of rows to insert in one 
transaction
             Default 5000. Set greater than 0. Set 1 to insert each row in each 
transaction
         """
+        warnings.warn(
+            "bulk_insert_rows is deprecated. Please use the insert_rows method 
instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+
         if not rows:
             raise ValueError("parameter rows could not be None or empty 
iterable")
-        conn = self.get_conn()
-        if self.supports_autocommit:
-            self.set_autocommit(conn, False)
-        cursor = conn.cursor()
-        cursor.fast_executemany = True
-        values_base = target_fields if target_fields else rows[0]
-        prepared_stm = "INSERT INTO {tablename} {columns} VALUES 
({values})".format(
-            tablename=table,
-            columns="({})".format(", ".join(target_fields)) if target_fields 
else "",
-            values=", ".join("?" for i in range(1, len(values_base) + 1)),
-        )
-        row_count = 0
-        # Chunk the rows
-        row_chunk = []
-        for row in rows:
-            row_chunk.append(row)
-            row_count += 1
-            if row_count % commit_every == 0:
-                cursor.executemany(prepared_stm, row_chunk)
-                conn.commit()  # type: ignore[attr-defined]
-                # Empty chunk
-                row_chunk = []
-        # Commit the leftover chunk
-        if len(row_chunk) > 0:
-            cursor.executemany(prepared_stm, row_chunk)
-            conn.commit()  # type: ignore[attr-defined]
-        self.log.info("[%s] inserted %s rows", table, row_count)
-        cursor.close()
-        conn.close()  # type: ignore[attr-defined]
+
+        self.insert_rows(table=table, rows=rows, target_fields=target_fields, 
commit_every=commit_every)
 
     def _get_conn_config_teradatasql(self) -> dict[str, Any]:
         """Return set of config params required for connecting to Teradata DB 
using teradatasql client."""
diff --git a/tests/deprecations_ignore.yml b/tests/deprecations_ignore.yml
index c88d2b9368..c8f02a193d 100644
--- a/tests/deprecations_ignore.yml
+++ b/tests/deprecations_ignore.yml
@@ -659,6 +659,8 @@
 - 
tests/providers/cncf/kubernetes/test_pod_generator.py::TestPodGenerator::test_pod_name_confirm_to_max_length
 - 
tests/providers/cncf/kubernetes/test_pod_generator.py::TestPodGenerator::test_pod_name_is_valid
 - 
tests/providers/cncf/kubernetes/test_template_rendering.py::test_render_k8s_pod_yaml
+- 
tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_insert_rows_executemany
+- 
tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_insert_rows_replace_executemany_hana_dialect
 - 
tests/providers/common/sql/hooks/test_dbapi.py::TestDbApiHook::test_instance_check_works_for_legacy_db_api_hook
 - 
tests/providers/common/sql/operators/test_sql.py::TestSQLCheckOperatorDbHook::test_get_hook
 - 
tests/providers/common/sql/operators/test_sql.py::TestSqlBranch::test_branch_false_with_dag_run
@@ -1059,6 +1061,10 @@
 - 
tests/providers/ssh/hooks/test_ssh.py::TestSSHHook::test_tunnel_without_password
 - 
tests/providers/tableau/hooks/test_tableau.py::TestTableauHook::test_get_conn_auth_via_token_and_site_in_init
 - 
tests/providers/tableau/hooks/test_tableau.py::TestTableauHook::test_get_conn_ssl_default
+- 
tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_with_fields
+- 
tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_with_commit_every
+- 
tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_without_fields
+- 
tests/providers/teradata/hooks/test_teradata.py::TestTeradataHook::test_bulk_insert_rows_no_rows
 - 
tests/providers/trino/operators/test_trino.py::test_execute_openlineage_events
 - 
tests/providers/vertica/operators/test_vertica.py::TestVerticaOperator::test_execute
 - 
tests/providers/weaviate/operators/test_weaviate.py::TestWeaviateIngestOperator::test_constructor
diff --git a/tests/providers/common/sql/hooks/test_dbapi.py 
b/tests/providers/common/sql/hooks/test_dbapi.py
index fd9886345f..2c34ee133e 100644
--- a/tests/providers/common/sql/hooks/test_dbapi.py
+++ b/tests/providers/common/sql/hooks/test_dbapi.py
@@ -21,6 +21,7 @@ import json
 from unittest import mock
 
 import pytest
+from pyodbc import Cursor
 
 from airflow.hooks.base import BaseHook
 from airflow.models import Connection
@@ -39,7 +40,7 @@ class TestDbApiHook:
     def setup_method(self, **kwargs):
         self.cur = mock.MagicMock(
             rowcount=0,
-            spec=["description", "rowcount", "execute", "executemany", 
"fetchall", "fetchone", "close"],
+            spec=Cursor,
         )
         self.conn = mock.MagicMock()
         self.conn.cursor.return_value = self.cur
diff --git a/tests/providers/postgres/hooks/test_postgres.py 
b/tests/providers/postgres/hooks/test_postgres.py
index 2d62cb4f43..8330ad3b1d 100644
--- a/tests/providers/postgres/hooks/test_postgres.py
+++ b/tests/providers/postgres/hooks/test_postgres.py
@@ -403,8 +403,7 @@ class TestPostgresHook:
         assert commit_count == self.conn.commit.call_count
 
         sql = f"INSERT INTO {table}  VALUES (%s)"
-        for row in rows:
-            self.cur.execute.assert_any_call(sql, row)
+        self.cur.executemany.assert_any_call(sql, rows)
 
     def test_insert_rows_replace(self):
         table = "table"
@@ -432,8 +431,7 @@ class TestPostgresHook:
             f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) "
             f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = 
excluded.{fields[1]}"
         )
-        for row in rows:
-            self.cur.execute.assert_any_call(sql, row)
+        self.cur.executemany.assert_any_call(sql, rows)
 
     def test_insert_rows_replace_missing_target_field_arg(self):
         table = "table"
@@ -497,8 +495,7 @@ class TestPostgresHook:
             f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) "
             f"ON CONFLICT ({', '.join(fields)}) DO NOTHING"
         )
-        for row in rows:
-            self.cur.execute.assert_any_call(sql, row)
+        self.cur.executemany.assert_any_call(sql, rows)
 
     def test_rowcount(self):
         hook = PostgresHook()
diff --git a/tests/providers/teradata/hooks/test_teradata.py 
b/tests/providers/teradata/hooks/test_teradata.py
index d47e987f41..af77d66a63 100644
--- a/tests/providers/teradata/hooks/test_teradata.py
+++ b/tests/providers/teradata/hooks/test_teradata.py
@@ -226,9 +226,9 @@ class TestTeradataHook:
             "str",
         ]
         self.test_db_hook.insert_rows("table", rows, target_fields)
-        self.cur.execute.assert_called_once_with(
+        self.cur.executemany.assert_called_once_with(
             "INSERT INTO table (basestring, none, datetime, int, float, str) 
VALUES (?,?,?,?,?,?)",
-            ("'test_string", None, "2023-08-15T00:00:00", "1", "3.14", "str"),
+            [("'test_string", None, "2023-08-15T00:00:00", "1", "3.14", 
"str")],
         )
 
     def test_bulk_insert_rows_with_fields(self):
@@ -236,7 +236,8 @@ class TestTeradataHook:
         target_fields = ["col1", "col2", "col3"]
         self.test_db_hook.bulk_insert_rows("table", rows, target_fields)
         self.cur.executemany.assert_called_once_with(
-            "INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)", rows
+            "INSERT INTO table (col1, col2, col3) VALUES (?,?,?)",
+            [("1", "2", "3"), ("4", "5", "6"), ("7", "8", "9")],
         )
 
     def test_bulk_insert_rows_with_commit_every(self):
@@ -244,19 +245,20 @@ class TestTeradataHook:
         target_fields = ["col1", "col2", "col3"]
         self.test_db_hook.bulk_insert_rows("table", rows, target_fields, 
commit_every=2)
         calls = [
-            mock.call("INSERT INTO table (col1, col2, col3) values (1, 2, 3)"),
-            mock.call("INSERT INTO table (col1, col2, col3) values (1, 2, 3)"),
-        ]
-        calls = [
-            mock.call("INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)", 
rows[:2]),
-            mock.call("INSERT INTO table (col1, col2, col3) VALUES (?, ?, ?)", 
rows[2:]),
+            mock.call(
+                "INSERT INTO table (col1, col2, col3) VALUES (?,?,?)", [("1", 
"2", "3"), ("4", "5", "6")]
+            ),
+            mock.call("INSERT INTO table (col1, col2, col3) VALUES (?,?,?)", 
[("7", "8", "9")]),
         ]
         self.cur.executemany.assert_has_calls(calls, any_order=True)
 
     def test_bulk_insert_rows_without_fields(self):
         rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
         self.test_db_hook.bulk_insert_rows("table", rows)
-        self.cur.executemany.assert_called_once_with("INSERT INTO table  
VALUES (?, ?, ?)", rows)
+        self.cur.executemany.assert_called_once_with(
+            "INSERT INTO table  VALUES (?,?,?)",
+            [("1", "2", "3"), ("4", "5", "6"), ("7", "8", "9")],
+        )
 
     def test_bulk_insert_rows_no_rows(self):
         rows = []

Reply via email to the mailing list.