This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6e22364278 ydb provider: add database to table name in bulk upsert, use bulk upsert in system test (#41303)
6e22364278 is described below

commit 6e223642780799e7b726eff6e307f2d270b9c689
Author: uzhastik <[email protected]>
AuthorDate: Wed Aug 14 03:05:40 2024 +0300

    ydb provider: add database to table name in bulk upsert, use bulk upsert in system test (#41303)
    
    * use bulk upsert in system test
    
    * fix bulk_upsert test
---
 airflow/providers/ydb/hooks/ydb.py                 |  9 +++--
 .../providers/ydb/operators/test_ydb.py            |  2 +-
 tests/providers/ydb/operators/test_ydb.py          |  6 ++--
 tests/system/providers/ydb/example_ydb.py          | 38 ++++++++++++++++++----
 4 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/airflow/providers/ydb/hooks/ydb.py b/airflow/providers/ydb/hooks/ydb.py
index fa82483e1c..e4580212d8 100644
--- a/airflow/providers/ydb/hooks/ydb.py
+++ b/airflow/providers/ydb/hooks/ydb.py
@@ -138,7 +138,7 @@ class YDBHook(DbApiHook):
         super().__init__(*args, **kwargs)
         self.is_ddl = is_ddl
 
-        conn: Connection = self.get_connection(getattr(self, self.conn_name_attr))
+        conn: Connection = self.get_connection(self.get_conn_id())
         host: str | None = conn.host
         if not host:
             raise ValueError("YDB host must be specified")
@@ -148,6 +148,7 @@ class YDBHook(DbApiHook):
         database: str | None = connection_extra.get("database")
         if not database:
             raise ValueError("YDB database must be specified")
+        self.database: str = database
 
         endpoint = f"{host}:{port}"
         credentials = get_credentials_from_connection(
@@ -222,15 +223,13 @@ class YDBHook(DbApiHook):
     @property
     def sqlalchemy_url(self) -> URL:
         conn: Connection = self.get_connection(self.get_conn_id())
-        connection_extra: dict[str, Any] = conn.extra_dejson
-        database: str | None = connection_extra.get("database")
         return URL.create(
             drivername="ydb",
             username=conn.login,
             password=conn.password,
             host=conn.host,
             port=conn.port,
-            query={"database": database},
+            query={"database": self.database},
         )
 
     def get_conn(self) -> YDBConnection:
@@ -249,7 +248,7 @@ class YDBHook(DbApiHook):
 
             https://ydb.tech/docs/en/recipes/ydb-sdk/bulk-upsert
         """
-        self.get_conn().bulk_upsert(table_name, rows, column_types)
+        self.get_conn().bulk_upsert(f"{self.database}/{table_name}", rows, column_types)
 
     @staticmethod
     def _get_table_client_settings() -> ydb.TableClientSettings:
diff --git a/tests/integration/providers/ydb/operators/test_ydb.py b/tests/integration/providers/ydb/operators/test_ydb.py
index fd81e282a7..b293ad1354 100644
--- a/tests/integration/providers/ydb/operators/test_ydb.py
+++ b/tests/integration/providers/ydb/operators/test_ydb.py
@@ -90,7 +90,7 @@ class TestYDBExecuteQueryOperator:
             {"id": 2, "name": "bears", "age": 22},
             {"id": 3, "name": "foxes", "age": 9},
         ]
-        hook.bulk_upsert("/local/team", rows=rows, column_types=column_types)
+        hook.bulk_upsert("team", rows=rows, column_types=column_types)
 
         result = age_sum_op.execute(self.mock_context)
         assert result == [(48,)]
diff --git a/tests/providers/ydb/operators/test_ydb.py b/tests/providers/ydb/operators/test_ydb.py
index a93e6fc1f5..d9e9340b59 100644
--- a/tests/providers/ydb/operators/test_ydb.py
+++ b/tests/providers/ydb/operators/test_ydb.py
@@ -118,7 +118,7 @@ class TestYDBExecuteQueryOperator:
         self, cursor_class, table_client_class, mock_session_pool, mock_driver, mock_get_connection
     ):
         mock_get_connection.return_value = Connection(
-            conn_type="ydb", host="localhost", extra={"database": "my_db"}
+            conn_type="ydb", host="localhost", extra={"database": "/my_db"}
         )
 
         cursor_class.return_value = FakeYDBCursor
@@ -156,8 +156,8 @@ class TestYDBExecuteQueryOperator:
             {"a": 1, "b": "hello"},
             {"a": 888, "b": "world"},
         ]
-        hook.bulk_upsert("/root/my_table", rows=rows, column_types=column_types)
+        hook.bulk_upsert("my_table", rows=rows, column_types=column_types)
         assert len(session_pool._pool_impl._driver.table_client.bulk_upsert_args) == 1
         arg0 = session_pool._pool_impl._driver.table_client.bulk_upsert_args[0]
-        assert arg0[0] == "/root/my_table"
+        assert arg0[0] == "/my_db/my_table"
         assert len(arg0[1]) == 2
diff --git a/tests/system/providers/ydb/example_ydb.py b/tests/system/providers/ydb/example_ydb.py
index 41df0333dd..8d43b6199a 100644
--- a/tests/system/providers/ydb/example_ydb.py
+++ b/tests/system/providers/ydb/example_ydb.py
@@ -19,7 +19,11 @@ from __future__ import annotations
 import datetime
 import os
 
+import ydb
+
 from airflow import DAG
+from airflow.decorators import task
+from airflow.providers.ydb.hooks.ydb import YDBHook
 from airflow.providers.ydb.operators.ydb import YDBExecuteQueryOperator
 
 # [START ydb_operator_howto_guide]
@@ -31,6 +35,26 @@ from airflow.providers.ydb.operators.ydb import YDBExecuteQueryOperator
 ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
 DAG_ID = "ydb_operator_dag"
 
+
+@task
+def populate_pet_table_via_bulk_upsert():
+    hook = YDBHook()
+    column_types = (
+        ydb.BulkUpsertColumns()
+        .add_column("pet_id", ydb.OptionalType(ydb.PrimitiveType.Int32))
+        .add_column("name", ydb.PrimitiveType.Utf8)
+        .add_column("pet_type", ydb.PrimitiveType.Utf8)
+        .add_column("birth_date", ydb.PrimitiveType.Utf8)
+        .add_column("owner", ydb.PrimitiveType.Utf8)
+    )
+
+    rows = [
+        {"pet_id": 3, "name": "Lester", "pet_type": "Hamster", "birth_date": "2020-06-23", "owner": "Lily"},
+        {"pet_id": 4, "name": "Quincy", "pet_type": "Parrot", "birth_date": "2013-08-11", "owner": "Anne"},
+    ]
+    hook.bulk_upsert("pet", rows=rows, column_types=column_types)
+
+
 with DAG(
     dag_id=DAG_ID,
     start_date=datetime.datetime(2020, 2, 2),
@@ -63,12 +87,6 @@ with DAG(
 
               UPSERT INTO pet (pet_id, name, pet_type, birth_date, owner)
               VALUES (2, 'Susie', 'Cat', '2019-05-01', 'Phil');
-
-              UPSERT INTO pet (pet_id, name, pet_type, birth_date, owner)
-              VALUES (3, 'Lester', 'Hamster', '2020-06-23', 'Lily');
-
-              UPSERT INTO pet (pet_id, name, pet_type, birth_date, owner)
-              VALUES (4, 'Quincy', 'Parrot', '2013-08-11', 'Anne');
             """,
     )
     # [END ydb_operator_howto_guide_populate_pet_table]
@@ -83,7 +101,13 @@ with DAG(
     )
     # [END ydb_operator_howto_guide_get_birth_date]
 
-    create_pet_table >> populate_pet_table >> get_all_pets >> get_birth_date
+    (
+        create_pet_table
+        >> populate_pet_table
+        >> populate_pet_table_via_bulk_upsert()
+        >> get_all_pets
+        >> get_birth_date
+    )
     # [END ydb_operator_howto_guide]
 
     from tests.system.utils.watcher import watcher

Reply via email to