This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 4ada175e3d fix: use `sqlalchemy_url` property in `get_uri` for postgresql provider (#38831) 4ada175e3d is described below commit 4ada175e3dc75b92dd13840d449917772a9b7c89 Author: Kalyan <kalyan.be...@live.com> AuthorDate: Thu May 9 12:52:28 2024 +0530 fix: use `sqlalchemy_url` property in `get_uri` for postgresql provider (#38831) * update get_uri * update get_uri Signed-off-by: kalyanr <kalyan.be...@live.com> * update docstring Signed-off-by: kalyanr <kalyan.be...@live.com> * add and use sa_uri property * update database in sa_uri * update tests * remove client_encoding from test_get_uri * use sqlalchemy_url property * add default port * update tests * update usage of ports * revert client_encoding updates --------- Signed-off-by: kalyanr <kalyan.be...@live.com> --- airflow/providers/postgres/hooks/postgres.py | 20 +++++++++++++++----- tests/providers/postgres/hooks/test_postgres.py | 4 ++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index 9e1b3a83d7..6a24cb0316 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -28,6 +28,7 @@ import psycopg2.extensions import psycopg2.extras from deprecated import deprecated from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor +from sqlalchemy.engine import URL from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.common.sql.hooks.sql import DbApiHook @@ -113,6 +114,18 @@ class PostgresHook(DbApiHook): def schema(self, value): self.database = value + @property + def sqlalchemy_url(self) -> URL: + conn = self.get_connection(getattr(self, self.conn_name_attr)) + return URL.create( + drivername="postgresql", + username=conn.login, + password=conn.password, + host=conn.host, + port=conn.port, + database=self.database or conn.schema, + ) + 
def _get_cursor(self, raw_cursor: str) -> CursorType: _cursor = raw_cursor.lower() cursor_types = { @@ -186,12 +199,9 @@ class PostgresHook(DbApiHook): def get_uri(self) -> str: """Extract the URI from the connection. - :return: the extracted uri. + :return: the extracted URI in Sqlalchemy URI format. """ - conn = self.get_connection(getattr(self, self.conn_name_attr)) - conn.schema = self.database or conn.schema - uri = conn.get_uri().replace("postgres://", "postgresql://") - return uri + return self.sqlalchemy_url.render_as_string(hide_password=False) def bulk_load(self, table: str, tmp_file: str) -> None: """Load a tab-delimited file into a database table.""" diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py index 8330ad3b1d..78d3414ab0 100644 --- a/tests/providers/postgres/hooks/test_postgres.py +++ b/tests/providers/postgres/hooks/test_postgres.py @@ -58,11 +58,11 @@ class TestPostgresHookConn: @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_uri(self, mock_connect): - self.connection.extra = json.dumps({"client_encoding": "utf-8"}) self.connection.conn_type = "postgres" + self.connection.port = 5432 self.db_hook.get_conn() assert mock_connect.call_count == 1 - assert self.db_hook.get_uri() == "postgresql://login:password@host/database?client_encoding=utf-8" + assert self.db_hook.get_uri() == "postgresql://login:password@host:5432/database" @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect") def test_get_conn_cursor(self, mock_connect):