This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 248e0a2eca Resolve postgres deprecations in tests (#40392)
248e0a2eca is described below

commit 248e0a2ecab130a39306cf99af329dcbdff9e60d
Author: Gopal Dirisala <[email protected]>
AuthorDate: Mon Jun 24 00:24:39 2024 +0530

    Resolve postgres deprecations in tests (#40392)
    
    * Resolve postgres deprecations in tests
    
    * Resolve postgres deprecations in tests
---
 airflow/providers/common/sql/operators/sql.py      |  5 ++-
 tests/deprecations_ignore.yml                      |  9 ----
 tests/providers/common/sql/operators/test_sql.py   | 11 +----
 tests/providers/postgres/hooks/test_postgres.py    |  7 +++-
 .../providers/postgres/operators/test_postgres.py  | 49 +++++++++++++++-------
 5 files changed, 47 insertions(+), 34 deletions(-)

diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index d50a6bf0f5..72c750d766 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -177,7 +177,10 @@ class BaseSQLOperator(BaseOperator):
             )
 
         if self.database:
-            hook.schema = self.database
+            if hook.conn_type == "postgres":
+                hook.database = self.database
+            else:
+                hook.schema = self.database
 
         return hook
 
diff --git a/tests/deprecations_ignore.yml b/tests/deprecations_ignore.yml
index 305b055c92..cf3ad35905 100644
--- a/tests/deprecations_ignore.yml
+++ b/tests/deprecations_ignore.yml
@@ -235,15 +235,6 @@
 - tests/providers/mysql/operators/test_mysql.py::TestMySql::test_mysql_operator_test_multi
 - tests/providers/mysql/operators/test_mysql.py::TestMySql::test_overwrite_schema
 - tests/providers/mysql/operators/test_mysql.py::test_execute_openlineage_events
-- tests/providers/postgres/hooks/test_postgres.py::TestPostgresHookConn::test_schema_kwarg_database_kwarg_compatibility
-- tests/providers/postgres/operators/test_postgres.py::test_parameters_are_templatized
-- tests/providers/postgres/operators/test_postgres.py::TestPostgres::test_overwrite_database
-- tests/providers/postgres/operators/test_postgres.py::TestPostgres::test_postgres_operator_test
-- tests/providers/postgres/operators/test_postgres.py::TestPostgres::test_postgres_operator_test_multi
-- tests/providers/postgres/operators/test_postgres.py::TestPostgres::test_runtime_parameter_setting
-- tests/providers/postgres/operators/test_postgres.py::TestPostgres::test_vacuum
-- tests/providers/postgres/operators/test_postgres.py::TestPostgresOpenLineage::test_postgres_operator_openlineage_explicit_schema
-- tests/providers/postgres/operators/test_postgres.py::TestPostgresOpenLineage::test_postgres_operator_openlineage_implicit_schema
 - tests/providers/snowflake/operators/test_snowflake.py::TestSnowflakeOperator::test_snowflake_operator
 - tests/providers/snowflake/operators/test_snowflake.py::TestSnowflakeOperatorForParams::test_overwrite_params
 - tests/providers/snowflake/operators/test_snowflake_sql.py::test_exec_success
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index e95a018f13..cc8fdaeff8 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -24,7 +24,7 @@ from unittest.mock import MagicMock
 import pytest
 
 from airflow import DAG
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException
 from airflow.models import Connection, DagRun, TaskInstance as TI, XCom
 from airflow.operators.empty import EmptyOperator
 from airflow.providers.common.sql.hooks.sql import fetch_all_handler
@@ -608,14 +608,7 @@ class TestSQLCheckOperatorDbHook:
         ) as mock_get_conn:
             if database:
                 self._operator.database = database
-            if database:
-                with pytest.warns(
-                    AirflowProviderDeprecationWarning,
-                    match='The "schema" variable has been renamed to "database" as it contained the database name.Please use "database" to set the database name.',
-                ):
-                    assert isinstance(self._operator._hook, PostgresHook)
-            else:
-                assert isinstance(self._operator._hook, PostgresHook)
+            assert isinstance(self._operator._hook, PostgresHook)
             mock_get_conn.assert_called_once_with(self.conn_id)
 
     def test_not_allowed_conn_type(self):
diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py
index 78d3414ab0..d73311427f 100644
--- a/tests/providers/postgres/hooks/test_postgres.py
+++ b/tests/providers/postgres/hooks/test_postgres.py
@@ -24,6 +24,7 @@ from unittest import mock
 import psycopg2.extras
 import pytest
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models import Connection
 from airflow.providers.postgres.hooks.postgres import PostgresHook
 from airflow.utils.types import NOTSET
@@ -266,7 +267,11 @@ class TestPostgresHookConn:
 
     def test_schema_kwarg_database_kwarg_compatibility(self):
         database = "database-override"
-        hook = PostgresHook(schema=database)
+        with pytest.warns(
+            AirflowProviderDeprecationWarning,
+            match='The "schema" arg has been renamed to "database" as it contained the database name.Please use "database" to set the database name.',
+        ):
+            hook = PostgresHook(schema=database)
         assert hook.database == database
 
     @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook")
diff --git a/tests/providers/postgres/operators/test_postgres.py b/tests/providers/postgres/operators/test_postgres.py
index 0875463196..54831882fe 100644
--- a/tests/providers/postgres/operators/test_postgres.py
+++ b/tests/providers/postgres/operators/test_postgres.py
@@ -20,14 +20,15 @@ from __future__ import annotations
 import pytest
 
 from airflow.models.dag import DAG
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
 from airflow.providers.postgres.hooks.postgres import PostgresHook
-from airflow.providers.postgres.operators.postgres import PostgresOperator
 from airflow.utils import timezone
 
 DEFAULT_DATE = timezone.datetime(2015, 1, 1)
 DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
 DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
 TEST_DAG_ID = "unit_test_dag"
+POSTGRES_DEFAULT = "postgres_default"
 
 
 @pytest.mark.backend("postgres")
@@ -51,11 +52,17 @@ class TestPostgres:
             dummy VARCHAR(50)
         );
         """
-        op = PostgresOperator(task_id="basic_postgres", sql=sql, dag=self.dag)
+        op = SQLExecuteQueryOperator(
+            task_id="basic_postgres", sql=sql, dag=self.dag, conn_id=POSTGRES_DEFAULT
+        )
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
-        autocommit_task = PostgresOperator(
-            task_id="basic_postgres_with_autocommit", sql=sql, dag=self.dag, autocommit=True
+        autocommit_task = SQLExecuteQueryOperator(
+            task_id="basic_postgres_with_autocommit",
+            sql=sql,
+            dag=self.dag,
+            autocommit=True,
+            conn_id=POSTGRES_DEFAULT,
         )
         autocommit_task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
@@ -65,7 +72,9 @@ class TestPostgres:
             "TRUNCATE TABLE test_airflow",
             "INSERT INTO test_airflow VALUES ('X')",
         ]
-        op = PostgresOperator(task_id="postgres_operator_test_multi", sql=sql, dag=self.dag)
+        op = SQLExecuteQueryOperator(
+            task_id="postgres_operator_test_multi", sql=sql, dag=self.dag, conn_id=POSTGRES_DEFAULT
+        )
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
     def test_vacuum(self):
@@ -74,7 +83,13 @@ class TestPostgres:
         """
 
         sql = "VACUUM ANALYZE;"
-        op = PostgresOperator(task_id="postgres_operator_test_vacuum", sql=sql, dag=self.dag, autocommit=True)
+        op = SQLExecuteQueryOperator(
+            task_id="postgres_operator_test_vacuum",
+            sql=sql,
+            dag=self.dag,
+            autocommit=True,
+            conn_id=POSTGRES_DEFAULT,
+        )
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
     def test_overwrite_database(self):
@@ -83,12 +98,13 @@ class TestPostgres:
         """
 
         sql = "SELECT 1;"
-        op = PostgresOperator(
+        op = SQLExecuteQueryOperator(
             task_id="postgres_operator_test_database_overwrite",
             sql=sql,
             dag=self.dag,
             autocommit=True,
             database="foobar",
+            conn_id=POSTGRES_DEFAULT,
         )
 
         from psycopg2 import OperationalError
@@ -103,11 +119,12 @@ class TestPostgres:
         """
 
         sql = "SELECT 1;"
-        op = PostgresOperator(
+        op = SQLExecuteQueryOperator(
             task_id="postgres_operator_test_runtime_parameter_setting",
             sql=sql,
             dag=self.dag,
-            runtime_parameters={"statement_timeout": "3000ms"},
+            hook_params={"options": "-c statement_timeout=3000ms"},
+            conn_id=POSTGRES_DEFAULT,
         )
         op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
         assert op.get_db_hook().get_first("SHOW statement_timeout;")[0] == "3s"
@@ -143,11 +160,12 @@ class TestPostgresOpenLineage:
             dummy VARCHAR(50)
         );
         """
-        op = PostgresOperator(
+        op = SQLExecuteQueryOperator(
             task_id="basic_postgres",
             sql=sql,
             dag=self.dag,
             hook_params={"options": "-c search_path=another_schema"},
+            conn_id=POSTGRES_DEFAULT,
         )
 
         lineage = op.get_openlineage_facets_on_start()
@@ -169,11 +187,12 @@ class TestPostgresOpenLineage:
             dummy VARCHAR(50)
         );
         """
-        op = PostgresOperator(
+        op = SQLExecuteQueryOperator(
             task_id="basic_postgres",
             sql=sql,
             dag=self.dag,
             hook_params={"options": "-c search_path=another_schema"},
+            conn_id=POSTGRES_DEFAULT,
         )
 
         lineage = op.get_openlineage_facets_on_start()
@@ -194,14 +213,16 @@ class TestPostgresOpenLineage:
 def test_parameters_are_templatized(create_task_instance_of_operator):
     """Test that PostgreSQL operator could template the same fields as SQLExecuteQueryOperator"""
     ti = create_task_instance_of_operator(
-        PostgresOperator,
-        postgres_conn_id="{{ param.conn_id }}",
+        SQLExecuteQueryOperator,
+        conn_id="{{ param.conn_id }}",
         sql="SELECT * FROM {{ param.table }} WHERE spam = %(spam)s;",
         parameters={"spam": "{{ param.bar }}"},
         dag_id="test-postgres-op-parameters-are-templatized",
         task_id="test-task",
     )
-    task: PostgresOperator = ti.render_templates({"param": {"conn_id": "pg", "table": "foo", "bar": "egg"}})
+    task: SQLExecuteQueryOperator = ti.render_templates(
+        {"param": {"conn_id": "pg", "table": "foo", "bar": "egg"}}
+    )
     assert task.conn_id == "pg"
     assert task.sql == "SELECT * FROM foo WHERE spam = %(spam)s;"
     assert task.parameters == {"spam": "egg"}

Reply via email to