This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 33c72387cac Add sequence insert support to OracleHook (#42947)
33c72387cac is described below
commit 33c72387caccbb99f3182dd344762cc46bd36530
Author: Lee2532 <[email protected]>
AuthorDate: Sun Oct 27 04:58:55 2024 +0900
Add sequence insert support to OracleHook (#42947)
* FEAT : oracle sequence column insert
* FEAT : Exception case if you enter only part of it
* FIX : pythonic code
---
.../src/airflow/providers/oracle/hooks/oracle.py | 31 ++++++++++++++++++----
providers/tests/oracle/hooks/test_oracle.py | 30 +++++++++++++++++++++
2 files changed, 56 insertions(+), 5 deletions(-)
diff --git a/providers/src/airflow/providers/oracle/hooks/oracle.py b/providers/src/airflow/providers/oracle/hooks/oracle.py
index a252a7599cd..3c51fe31c8a 100644
--- a/providers/src/airflow/providers/oracle/hooks/oracle.py
+++ b/providers/src/airflow/providers/oracle/hooks/oracle.py
@@ -328,6 +328,8 @@ class OracleHook(DbApiHook):
rows: list[tuple],
target_fields: list[str] | None = None,
commit_every: int = 5000,
+ sequence_column: str | None = None,
+ sequence_name: str | None = None,
):
"""
Perform bulk inserts efficiently for Oracle DB.
@@ -342,6 +344,8 @@ class OracleHook(DbApiHook):
If None, each rows should have some order as table columns name
:param commit_every: the maximum number of rows to insert in one transaction
Default 5000. Set greater than 0. Set 1 to insert each row in each transaction
+ :param sequence_column: the column name to which the sequence will be applied, default None.
+ :param sequence_name: the names of the sequence_name in the table, default None.
"""
if not rows:
raise ValueError("parameter rows could not be None or empty iterable")
@@ -350,11 +354,28 @@ class OracleHook(DbApiHook):
self.set_autocommit(conn, False)
cursor = conn.cursor() # type: ignore[attr-defined]
values_base = target_fields or rows[0]
- prepared_stm = "insert into {tablename} {columns} values ({values})".format(
- tablename=table,
- columns="({})".format(", ".join(target_fields)) if target_fields else "",
- values=", ".join(f":{i}" for i in range(1, len(values_base) + 1)),
- )
+
+ if bool(sequence_column) ^ bool(sequence_name):
+ raise ValueError(
+ "Parameters 'sequence_column' and 'sequence_name' must be provided together or not at all."
+ )
+
+ if sequence_column and sequence_name:
+ prepared_stm = "insert into {tablename} {columns} values ({values})".format(
+ tablename=table,
+ columns="({})".format(", ".join([sequence_column] + target_fields))
+ if target_fields
+ else f"({sequence_column})",
+ values=", ".join(
+ [f"{sequence_name}.NEXTVAL"] + [f":{i}" for i in range(1, len(values_base) + 1)]
+ ),
+ )
+ else:
+ prepared_stm = "insert into {tablename} {columns} values ({values})".format(
+ tablename=table,
+ columns="({})".format(", ".join(target_fields)) if target_fields else "",
+ values=", ".join(f":{i}" for i in range(1, len(values_base) + 1)),
+ )
row_count = 0
# Chunk the rows
row_chunk = []
diff --git a/providers/tests/oracle/hooks/test_oracle.py b/providers/tests/oracle/hooks/test_oracle.py
index fc4709020eb..2650d8f7ca9 100644
--- a/providers/tests/oracle/hooks/test_oracle.py
+++ b/providers/tests/oracle/hooks/test_oracle.py
@@ -369,6 +369,36 @@ class TestOracleHook:
with pytest.raises(ValueError):
self.db_hook.bulk_insert_rows("table", rows)
+ def test_bulk_insert_sequence_field(self):
+ rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
+ target_fields = ["col1", "col2", "col3"]
+ sequence_column = "id"
+ sequence_name = "my_sequence"
+ self.db_hook.bulk_insert_rows(
+ "table", rows, target_fields, sequence_column=sequence_column, sequence_name=sequence_name
+ )
+ self.cur.prepare.assert_called_once_with(
+ "insert into table (id, col1, col2, col3) values (my_sequence.NEXTVAL, :1, :2, :3)"
+ )
+ self.cur.executemany.assert_called_once_with(None, rows)
+
+ def test_bulk_insert_sequence_without_parameter(self):
+ rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
+ target_fields = ["col1", "col2", "col3"]
+ sequence_column = "id"
+ sequence_name = None
+ with pytest.raises(ValueError):
+ self.db_hook.bulk_insert_rows(
+ "table", rows, target_fields, sequence_column=sequence_column, sequence_name=sequence_name
+ )
+
+ sequence_column = None
+ sequence_name = "my_sequence"
+ with pytest.raises(ValueError):
+ self.db_hook.bulk_insert_rows(
+ "table", rows, target_fields, sequence_column=sequence_column, sequence_name=sequence_name
+ )
+
def test_callproc_none(self):
parameters = None