This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 33c72387cac Add sequence insert support to OracleHook (#42947)
33c72387cac is described below

commit 33c72387caccbb99f3182dd344762cc46bd36530
Author: Lee2532 <[email protected]>
AuthorDate: Sun Oct 27 04:58:55 2024 +0900

    Add sequence insert support to OracleHook (#42947)
    
    * FEAT: support inserting into an Oracle sequence column
    
    * FEAT: raise an exception if only one of sequence_column and sequence_name is provided
    
    * FIX: make the code more Pythonic
---
 .../src/airflow/providers/oracle/hooks/oracle.py   | 31 ++++++++++++++++++----
 providers/tests/oracle/hooks/test_oracle.py        | 30 +++++++++++++++++++++
 2 files changed, 56 insertions(+), 5 deletions(-)

diff --git a/providers/src/airflow/providers/oracle/hooks/oracle.py b/providers/src/airflow/providers/oracle/hooks/oracle.py
index a252a7599cd..3c51fe31c8a 100644
--- a/providers/src/airflow/providers/oracle/hooks/oracle.py
+++ b/providers/src/airflow/providers/oracle/hooks/oracle.py
@@ -328,6 +328,8 @@ class OracleHook(DbApiHook):
         rows: list[tuple],
         target_fields: list[str] | None = None,
         commit_every: int = 5000,
+        sequence_column: str | None = None,
+        sequence_name: str | None = None,
     ):
         """
         Perform bulk inserts efficiently for Oracle DB.
@@ -342,6 +344,8 @@ class OracleHook(DbApiHook):
             If None, each rows should have some order as table columns name
         :param commit_every: the maximum number of rows to insert in one 
transaction
             Default 5000. Set greater than 0. Set 1 to insert each row in each 
transaction
+        :param sequence_column: the column name to which the sequence will be 
applied, default None.
+        :param sequence_name: the names of the sequence_name in the table, 
default None.
         """
         if not rows:
             raise ValueError("parameter rows could not be None or empty 
iterable")
@@ -350,11 +354,28 @@ class OracleHook(DbApiHook):
             self.set_autocommit(conn, False)
         cursor = conn.cursor()  # type: ignore[attr-defined]
         values_base = target_fields or rows[0]
-        prepared_stm = "insert into {tablename} {columns} values 
({values})".format(
-            tablename=table,
-            columns="({})".format(", ".join(target_fields)) if target_fields 
else "",
-            values=", ".join(f":{i}" for i in range(1, len(values_base) + 1)),
-        )
+
+        if bool(sequence_column) ^ bool(sequence_name):
+            raise ValueError(
+                "Parameters 'sequence_column' and 'sequence_name' must be 
provided together or not at all."
+            )
+
+        if sequence_column and sequence_name:
+            prepared_stm = "insert into {tablename} {columns} values 
({values})".format(
+                tablename=table,
+                columns="({})".format(", ".join([sequence_column] + 
target_fields))
+                if target_fields
+                else f"({sequence_column})",
+                values=", ".join(
+                    [f"{sequence_name}.NEXTVAL"] + [f":{i}" for i in range(1, 
len(values_base) + 1)]
+                ),
+            )
+        else:
+            prepared_stm = "insert into {tablename} {columns} values 
({values})".format(
+                tablename=table,
+                columns="({})".format(", ".join(target_fields)) if 
target_fields else "",
+                values=", ".join(f":{i}" for i in range(1, len(values_base) + 
1)),
+            )
         row_count = 0
         # Chunk the rows
         row_chunk = []
diff --git a/providers/tests/oracle/hooks/test_oracle.py b/providers/tests/oracle/hooks/test_oracle.py
index fc4709020eb..2650d8f7ca9 100644
--- a/providers/tests/oracle/hooks/test_oracle.py
+++ b/providers/tests/oracle/hooks/test_oracle.py
@@ -369,6 +369,36 @@ class TestOracleHook:
         with pytest.raises(ValueError):
             self.db_hook.bulk_insert_rows("table", rows)
 
+    def test_bulk_insert_sequence_field(self):
+        rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
+        target_fields = ["col1", "col2", "col3"]
+        sequence_column = "id"
+        sequence_name = "my_sequence"
+        self.db_hook.bulk_insert_rows(
+            "table", rows, target_fields, sequence_column=sequence_column, 
sequence_name=sequence_name
+        )
+        self.cur.prepare.assert_called_once_with(
+            "insert into table (id, col1, col2, col3) values 
(my_sequence.NEXTVAL, :1, :2, :3)"
+        )
+        self.cur.executemany.assert_called_once_with(None, rows)
+
+    def test_bulk_insert_sequence_without_parameter(self):
+        rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
+        target_fields = ["col1", "col2", "col3"]
+        sequence_column = "id"
+        sequence_name = None
+        with pytest.raises(ValueError):
+            self.db_hook.bulk_insert_rows(
+                "table", rows, target_fields, sequence_column=sequence_column, 
sequence_name=sequence_name
+            )
+
+        sequence_column = None
+        sequence_name = "my_sequence"
+        with pytest.raises(ValueError):
+            self.db_hook.bulk_insert_rows(
+                "table", rows, target_fields, sequence_column=sequence_column, 
sequence_name=sequence_name
+            )
+
     def test_callproc_none(self):
         parameters = None
 

Reply via email to