This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 57d091ac58 Fix all provider tests for database isolation mode (#41242)
57d091ac58 is described below

commit 57d091ac58eea20bf507f4487c18634f23bcd914
Author: Jarek Potiuk <[email protected]>
AuthorDate: Sat Aug 3 21:15:30 2024 +0200

    Fix all provider tests for database isolation mode (#41242)
---
 tests/conftest.py                                  |  4 ++
 tests/providers/airbyte/hooks/test_airbyte.py      |  3 +-
 .../alibaba/cloud/log/test_oss_task_handler.py     |  4 +-
 .../amazon/aws/log/test_cloudwatch_task_handler.py | 14 +++----
 .../amazon/aws/log/test_s3_task_handler.py         | 17 ++++-----
 .../providers/amazon/aws/operators/test_appflow.py |  4 +-
 tests/providers/apache/druid/hooks/test_druid.py   |  3 ++
 .../providers/apache/druid/operators/test_druid.py |  6 ++-
 .../databricks/operators/test_databricks_copy.py   |  2 +-
 .../providers/google/cloud/links/test_translate.py | 20 +++++++---
 .../google/cloud/log/test_gcs_task_handler.py      |  8 +++-
 .../google/cloud/operators/test_automl.py          | 36 +++++++++++++-----
 .../google/cloud/operators/test_bigquery.py        | 16 ++++++--
 .../test_cloud_storage_transfer_service.py         | 44 ++++++++++++++++------
 .../google/cloud/operators/test_compute.py         | 12 ++++--
 .../google/cloud/operators/test_dataprep.py        | 13 +++++--
 16 files changed, 145 insertions(+), 61 deletions(-)
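
The recurring fix across the provider tests below is to persist the freshly created TaskInstance through the "session" fixture (session.add(ti) followed by session.commit()) before the test renders templates, pushes XCom, or reads task logs: in database isolation mode the instance must already exist in the database for those calls to resolve it. A minimal sketch of the pattern, assuming Airflow's create_task_instance_of_operator and session pytest fixtures; the BashOperator and the dag/task ids are illustrative, not part of this commit:

    import pytest

    from airflow.operators.bash import BashOperator  # any operator with a templated field


    @pytest.mark.db_test
    def test_templating_pattern(create_task_instance_of_operator, session):
        ti = create_task_instance_of_operator(
            BashOperator,
            dag_id="example_dag",
            task_id="example_task",
            bash_command="{{ dag.dag_id }}",  # templated field
        )
        # Persist the TaskInstance first; an uncommitted instance cannot be
        # looked up when template rendering goes through the database layer.
        session.add(ti)
        session.commit()
        ti.render_templates()
        assert ti.task.bash_command == "example_dag"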

diff --git a/tests/conftest.py b/tests/conftest.py
index a3d8ed1d3e..ebb25ea6a8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1192,6 +1192,10 @@ def create_log_template(request):
         session.commit()
 
         def _delete_log_template():
+            from airflow.models import DagRun, TaskInstance
+
+            session.query(TaskInstance).delete()
+            session.query(DagRun).delete()
             session.delete(log_template)
             session.commit()
 
diff --git a/tests/providers/airbyte/hooks/test_airbyte.py b/tests/providers/airbyte/hooks/test_airbyte.py
index 6cf211909e..e91227a3a0 100644
--- a/tests/providers/airbyte/hooks/test_airbyte.py
+++ b/tests/providers/airbyte/hooks/test_airbyte.py
@@ -26,7 +26,8 @@ from airflow.models import Connection
 from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
 from airflow.utils import db
 
-pytestmark = pytest.mark.db_test
+# those tests will not work with database isolation because they mock requests
+pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]
 
 
 class TestAirbyteHook:
diff --git a/tests/providers/alibaba/cloud/log/test_oss_task_handler.py b/tests/providers/alibaba/cloud/log/test_oss_task_handler.py
index 5bcad1a7a4..2c00c4d022 100644
--- a/tests/providers/alibaba/cloud/log/test_oss_task_handler.py
+++ b/tests/providers/alibaba/cloud/log/test_oss_task_handler.py
@@ -47,7 +47,7 @@ class TestOSSTaskHandler:
         self.oss_task_handler = OSSTaskHandler(self.base_log_folder, self.oss_log_folder)
 
     @pytest.fixture(autouse=True)
-    def task_instance(self, create_task_instance):
+    def task_instance(self, create_task_instance, dag_maker):
         self.ti = ti = create_task_instance(
             dag_id="dag_for_testing_oss_task_handler",
             task_id="task_for_testing_oss_task_handler",
@@ -56,6 +56,8 @@ class TestOSSTaskHandler:
         )
         ti.try_number = 1
         ti.raw = False
+        dag_maker.session.merge(ti)
+        dag_maker.session.commit()
         yield
         clear_db_runs()
         clear_db_dags()
diff --git a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
index 4f9ecb1664..36a51bbb74 100644
--- a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -33,7 +33,6 @@ from airflow.operators.empty import EmptyOperator
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudwatchTaskHandler
 from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
-from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
 from tests.test_utils.config import conf_vars
@@ -54,7 +53,7 @@ def logmock():
 class TestCloudwatchTaskHandler:
     @conf_vars({("logging", "remote_log_conn_id"): "aws_default"})
     @pytest.fixture(autouse=True)
-    def setup_tests(self, create_log_template, tmp_path_factory):
+    def setup_tests(self, create_log_template, tmp_path_factory, session):
         self.remote_log_group = "log_group_name"
         self.region_name = "us-west-2"
         self.local_log_location = str(tmp_path_factory.mktemp("local-cloudwatch-log-location"))
@@ -70,15 +69,16 @@ class TestCloudwatchTaskHandler:
         self.dag = DAG(dag_id=dag_id, start_date=date)
         task = EmptyOperator(task_id=task_id, dag=self.dag)
         dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=date, run_id="test", run_type="scheduled")
-        with create_session() as session:
-            session.add(dag_run)
-            session.commit()
-            session.refresh(dag_run)
+        session.add(dag_run)
+        session.commit()
+        session.refresh(dag_run)
 
         self.ti = TaskInstance(task=task, run_id=dag_run.run_id)
         self.ti.dag_run = dag_run
         self.ti.try_number = 1
         self.ti.state = State.RUNNING
+        session.add(self.ti)
+        session.commit()
 
         self.remote_log_stream = (f"{dag_id}/{task_id}/{date.isoformat()}/{self.ti.try_number}.log").replace(
             ":", "_"
@@ -88,8 +88,6 @@ class TestCloudwatchTaskHandler:
         yield
 
         self.cloudwatch_task_handler.handler = None
-        with create_session() as session:
-            session.query(DagRun).delete()
 
     def test_hook(self):
         assert isinstance(self.cloudwatch_task_handler.hook, AwsLogsHook)
diff --git a/tests/providers/amazon/aws/log/test_s3_task_handler.py b/tests/providers/amazon/aws/log/test_s3_task_handler.py
index dd7a81904c..7bec287105 100644
--- a/tests/providers/amazon/aws/log/test_s3_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_s3_task_handler.py
@@ -31,7 +31,6 @@ from airflow.models import DAG, DagRun, TaskInstance
 from airflow.operators.empty import EmptyOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.log.s3_task_handler import S3TaskHandler
-from airflow.utils.session import create_session
 from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.timezone import datetime
 from tests.test_utils.config import conf_vars
@@ -47,7 +46,7 @@ def s3mock():
 class TestS3TaskHandler:
     @conf_vars({("logging", "remote_log_conn_id"): "aws_default"})
     @pytest.fixture(autouse=True)
-    def setup_tests(self, create_log_template, tmp_path_factory):
+    def setup_tests(self, create_log_template, tmp_path_factory, session):
         self.remote_log_base = "s3://bucket/remote/log/location"
         self.remote_log_location = "s3://bucket/remote/log/location/1.log"
         self.remote_log_key = "remote/log/location/1.log"
@@ -61,26 +60,24 @@ class TestS3TaskHandler:
         self.dag = DAG("dag_for_testing_s3_task_handler", start_date=date)
         task = EmptyOperator(task_id="task_for_testing_s3_log_handler", 
dag=self.dag)
         dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=date, run_id="test", run_type="manual")
-        with create_session() as session:
-            session.add(dag_run)
-            session.commit()
-            session.refresh(dag_run)
+        session.add(dag_run)
+        session.commit()
+        session.refresh(dag_run)
 
         self.ti = TaskInstance(task=task, run_id=dag_run.run_id)
         self.ti.dag_run = dag_run
         self.ti.try_number = 1
         self.ti.state = State.RUNNING
+        session.add(self.ti)
+        session.commit()
 
         self.conn = boto3.client("s3")
         self.conn.create_bucket(Bucket="bucket")
-
         yield
 
         self.dag.clear()
 
-        with create_session() as session:
-            session.query(DagRun).delete()
-
+        session.query(DagRun).delete()
         if self.s3_task_handler.handler:
             with contextlib.suppress(Exception):
                 os.remove(self.s3_task_handler.handler.baseFilename)
diff --git a/tests/providers/amazon/aws/operators/test_appflow.py b/tests/providers/amazon/aws/operators/test_appflow.py
index c49a0ca69c..58f79f6ea3 100644
--- a/tests/providers/amazon/aws/operators/test_appflow.py
+++ b/tests/providers/amazon/aws/operators/test_appflow.py
@@ -54,12 +54,14 @@ AppflowBaseOperator.UPDATE_PROPAGATION_TIME = 0  # avoid wait
 
 @pytest.mark.db_test
 @pytest.fixture
-def ctx(create_task_instance):
+def ctx(create_task_instance, session):
     ti = create_task_instance(
         dag_id=DAG_ID,
         task_id=TASK_ID,
         schedule="0 12 * * *",
     )
+    session.add(ti)
+    session.commit()
     return {"task_instance": ti}
 
 
diff --git a/tests/providers/apache/druid/hooks/test_druid.py b/tests/providers/apache/druid/hooks/test_druid.py
index 0d42695f06..9befbf37f0 100644
--- a/tests/providers/apache/druid/hooks/test_druid.py
+++ b/tests/providers/apache/druid/hooks/test_druid.py
@@ -26,6 +26,9 @@ from airflow.exceptions import AirflowException
 from airflow.providers.apache.druid.hooks.druid import DruidDbApiHook, DruidHook, IngestionType
 
 
+# This test mocks the requests library to avoid making actual HTTP requests so database isolation mode
+# will not work for it
[email protected]_if_database_isolation_mode
 @pytest.mark.db_test
 class TestDruidSubmitHook:
     def setup_method(self):
diff --git a/tests/providers/apache/druid/operators/test_druid.py b/tests/providers/apache/druid/operators/test_druid.py
index 286cdd3916..9aa85235f7 100644
--- a/tests/providers/apache/druid/operators/test_druid.py
+++ b/tests/providers/apache/druid/operators/test_druid.py
@@ -50,6 +50,7 @@ RENDERED_INDEX = {
 }
 
 
[email protected]_serialized_dag
 @pytest.mark.db_test
 def test_render_template(dag_maker):
     with dag_maker("test_druid_render_template", default_args={"start_date": 
DEFAULT_DATE}):
@@ -59,7 +60,10 @@ def test_render_template(dag_maker):
             params={"index_type": "index_hadoop", "datasource": 
"datasource_prd"},
         )
 
-    dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED).task_instances[0].render_templates()
+    dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+    dag_maker.session.add(dag_run.task_instances[0])
+    dag_maker.session.commit()
+    dag_run.task_instances[0].render_templates()
     assert RENDERED_INDEX == json.loads(operator.json_index_file)
 
 
diff --git a/tests/providers/databricks/operators/test_databricks_copy.py b/tests/providers/databricks/operators/test_databricks_copy.py
index a481841868..37970efe35 100644
--- a/tests/providers/databricks/operators/test_databricks_copy.py
+++ b/tests/providers/databricks/operators/test_databricks_copy.py
@@ -232,7 +232,7 @@ def test_incorrect_params_wrong_format():
 
 
 @pytest.mark.db_test
-def test_templating(create_task_instance_of_operator):
+def test_templating(create_task_instance_of_operator, session):
     ti = create_task_instance_of_operator(
         DatabricksCopyIntoOperator,
         # Templated fields
diff --git a/tests/providers/google/cloud/links/test_translate.py b/tests/providers/google/cloud/links/test_translate.py
index c0e244acff..82547907a8 100644
--- a/tests/providers/google/cloud/links/test_translate.py
+++ b/tests/providers/google/cloud/links/test_translate.py
@@ -47,7 +47,7 @@ MODEL = "test-model"
 
 class TestTranslationLegacyDatasetLink:
     @pytest.mark.db_test
-    def test_get_link(self, create_task_instance_of_operator):
+    def test_get_link(self, create_task_instance_of_operator, session):
         expected_url = f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/sentences?project={GCP_PROJECT_ID}"
         link = TranslationLegacyDatasetLink()
         ti = create_task_instance_of_operator(
@@ -57,6 +57,8 @@ class TestTranslationLegacyDatasetLink:
             dataset=DATASET,
             location=GCP_LOCATION,
         )
+        session.add(ti)
+        session.commit()
         link.persist(context={"ti": ti}, task_instance=ti.task, 
dataset_id=DATASET, project_id=GCP_PROJECT_ID)
         actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
         assert actual_url == expected_url
@@ -64,7 +66,7 @@ class TestTranslationLegacyDatasetLink:
 
 class TestTranslationDatasetListLink:
     @pytest.mark.db_test
-    def test_get_link(self, create_task_instance_of_operator):
+    def test_get_link(self, create_task_instance_of_operator, session):
         expected_url = f"{TRANSLATION_BASE_LINK}/datasets?project={GCP_PROJECT_ID}"
         link = TranslationDatasetListLink()
         ti = create_task_instance_of_operator(
@@ -73,6 +75,8 @@ class TestTranslationDatasetListLink:
             task_id="test_dataset_list_link_task",
             location=GCP_LOCATION,
         )
+        session.add(ti)
+        session.commit()
         link.persist(context={"ti": ti}, task_instance=ti.task, 
project_id=GCP_PROJECT_ID)
         actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
         assert actual_url == expected_url
@@ -80,7 +84,7 @@ class TestTranslationDatasetListLink:
 
 class TestTranslationLegacyModelLink:
     @pytest.mark.db_test
-    def test_get_link(self, create_task_instance_of_operator):
+    def test_get_link(self, create_task_instance_of_operator, session):
         expected_url = (
             f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
             f"evaluate;modelId={MODEL}?project={GCP_PROJECT_ID}"
@@ -94,6 +98,8 @@ class TestTranslationLegacyModelLink:
             project_id=GCP_PROJECT_ID,
             location=GCP_LOCATION,
         )
+        session.add(ti)
+        session.commit()
         link.persist(
             context={"ti": ti},
             task_instance=ti.task,
@@ -107,7 +113,7 @@ class TestTranslationLegacyModelLink:
 
 class TestTranslationLegacyModelTrainLink:
     @pytest.mark.db_test
-    def test_get_link(self, create_task_instance_of_operator):
+    def test_get_link(self, create_task_instance_of_operator, session):
         expected_url = (
             f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
             f"train?project={GCP_PROJECT_ID}"
@@ -121,6 +127,8 @@ class TestTranslationLegacyModelTrainLink:
             project_id=GCP_PROJECT_ID,
             location=GCP_LOCATION,
         )
+        session.add(ti)
+        session.commit()
         link.persist(
             context={"ti": ti},
             task_instance=ti.task,
@@ -132,7 +140,7 @@ class TestTranslationLegacyModelTrainLink:
 
 class TestTranslationLegacyModelPredictLink:
     @pytest.mark.db_test
-    def test_get_link(self, create_task_instance_of_operator):
+    def test_get_link(self, create_task_instance_of_operator, session):
         expected_url = (
             f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
             f"predict;modelId={MODEL}?project={GCP_PROJECT_ID}"
@@ -149,6 +157,8 @@ class TestTranslationLegacyModelPredictLink:
             output_config="input_config",
         )
         ti.task.model = Model(dataset_id=DATASET, display_name=MODEL)
+        session.add(ti)
+        session.commit()
         link.persist(context={"ti": ti}, task_instance=ti.task, 
model_id=MODEL, project_id=GCP_PROJECT_ID)
         actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
         assert actual_url == expected_url
diff --git a/tests/providers/google/cloud/log/test_gcs_task_handler.py b/tests/providers/google/cloud/log/test_gcs_task_handler.py
index 1344b2c797..a860e52e15 100644
--- a/tests/providers/google/cloud/log/test_gcs_task_handler.py
+++ b/tests/providers/google/cloud/log/test_gcs_task_handler.py
@@ -34,7 +34,7 @@ from tests.test_utils.db import clear_db_dags, clear_db_runs
 @pytest.mark.db_test
 class TestGCSTaskHandler:
     @pytest.fixture(autouse=True)
-    def task_instance(self, create_task_instance):
+    def task_instance(self, create_task_instance, session):
         self.ti = ti = create_task_instance(
             dag_id="dag_for_testing_gcs_task_handler",
             task_id="task_for_testing_gcs_task_handler",
@@ -43,6 +43,8 @@ class TestGCSTaskHandler:
         )
         ti.try_number = 1
         ti.raw = False
+        session.add(ti)
+        session.commit()
         yield
         clear_db_runs()
         clear_db_dags()
@@ -91,13 +93,15 @@ class TestGCSTaskHandler:
     )
     @mock.patch("google.cloud.storage.Client")
     @mock.patch("google.cloud.storage.Blob")
-    def test_should_read_logs_from_remote(self, mock_blob, mock_client, mock_creds):
+    def test_should_read_logs_from_remote(self, mock_blob, mock_client, mock_creds, session):
         mock_obj = MagicMock()
         mock_obj.name = "remote/log/location/1.log"
         mock_client.return_value.list_blobs.return_value = [mock_obj]
         mock_blob.from_string.return_value.download_as_bytes.return_value = b"CONTENT"
         ti = copy.copy(self.ti)
         ti.state = TaskInstanceState.SUCCESS
+        session.add(ti)
+        session.commit()
         logs, metadata = self.gcs_task_handler._read(ti, self.ti.try_number)
         mock_blob.from_string.assert_called_once_with(
             "gs://bucket/remote/log/location/1.log", mock_client.return_value
diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py
index 985c5ef0aa..fbe1753753 100644
--- a/tests/providers/google/cloud/operators/test_automl.py
+++ b/tests/providers/google/cloud/operators/test_automl.py
@@ -119,7 +119,7 @@ class TestAutoMLTrainModelOperator:
         mock_hook.assert_not_called()
 
     @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator):
+    def test_templating(self, create_task_instance_of_operator, session):
         ti = create_task_instance_of_operator(
             AutoMLTrainModelOperator,
             # Templated fields
@@ -131,6 +131,8 @@ class TestAutoMLTrainModelOperator:
             task_id="test_template_body_templating_task",
             execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         task: AutoMLTrainModelOperator = ti.task
         assert task.model == "model"
@@ -207,7 +209,7 @@ class TestAutoMLBatchPredictOperator:
         mock_hook.return_value.batch_predict.assert_not_called()
 
     @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator):
+    def test_templating(self, create_task_instance_of_operator, session):
         ti = create_task_instance_of_operator(
             AutoMLBatchPredictOperator,
             # Templated fields
@@ -222,6 +224,8 @@ class TestAutoMLBatchPredictOperator:
             task_id="test_template_body_templating_task",
             execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         task: AutoMLBatchPredictOperator = ti.task
         assert task.model_id == "model"
@@ -265,7 +269,7 @@ class TestAutoMLPredictOperator:
         )
 
     @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator):
+    def test_templating(self, create_task_instance_of_operator, session):
         ti = create_task_instance_of_operator(
             AutoMLPredictOperator,
             # Templated fields
@@ -279,6 +283,8 @@ class TestAutoMLPredictOperator:
             execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
             payload={},
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         task: AutoMLPredictOperator = ti.task
         assert task.model_id == "model-id"
@@ -372,7 +378,7 @@ class TestAutoMLCreateImportOperator:
         mock_hook.return_value.create_dataset.assert_not_called()
 
     @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator):
+    def test_templating(self, create_task_instance_of_operator, session):
         ti = create_task_instance_of_operator(
             AutoMLCreateDatasetOperator,
             # Templated fields
@@ -385,6 +391,8 @@ class TestAutoMLCreateImportOperator:
             task_id="test_template_body_templating_task",
             execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         task: AutoMLCreateDatasetOperator = ti.task
         assert task.dataset == "dataset"
@@ -530,7 +538,7 @@ class TestAutoMLGetModelOperator:
         )
 
     @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator):
+    def test_templating(self, create_task_instance_of_operator, session):
         ti = create_task_instance_of_operator(
             AutoMLGetModelOperator,
             # Templated fields
@@ -543,6 +551,8 @@ class TestAutoMLGetModelOperator:
             task_id="test_template_body_templating_task",
             execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         task: AutoMLGetModelOperator = ti.task
         assert task.model_id == "model-id"
@@ -599,7 +609,7 @@ class TestAutoMLDeleteModelOperator:
         mock_hook.return_value.delete_model.assert_not_called()
 
     @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator):
+    def test_templating(self, create_task_instance_of_operator, session):
         ti = create_task_instance_of_operator(
             AutoMLDeleteModelOperator,
             # Templated fields
@@ -612,6 +622,8 @@ class TestAutoMLDeleteModelOperator:
             task_id="test_template_body_templating_task",
             execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         task: AutoMLDeleteModelOperator = ti.task
         assert task.model_id == "model-id"
@@ -712,7 +724,7 @@ class TestAutoMLDatasetImportOperator:
         mock_hook.return_value.import_data.assert_not_called()
 
     @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator):
+    def test_templating(self, create_task_instance_of_operator, session):
         ti = create_task_instance_of_operator(
             AutoMLImportDataOperator,
             # Templated fields
@@ -726,6 +738,8 @@ class TestAutoMLDatasetImportOperator:
             task_id="test_template_body_templating_task",
             execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         task: AutoMLImportDataOperator = ti.task
         assert task.dataset_id == "dataset-id"
@@ -811,7 +825,7 @@ class TestAutoMLDatasetListOperator:
         )
 
     @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator):
+    def test_templating(self, create_task_instance_of_operator, session):
         ti = create_task_instance_of_operator(
             AutoMLListDatasetOperator,
             # Templated fields
@@ -823,6 +837,8 @@ class TestAutoMLDatasetListOperator:
             task_id="test_template_body_templating_task",
             execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         task: AutoMLListDatasetOperator = ti.task
         assert task.location == "location"
@@ -878,7 +894,7 @@ class TestAutoMLDatasetDeleteOperator:
         mock_hook.return_value.delete_dataset.assert_not_called()
 
     @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator):
+    def test_templating(self, create_task_instance_of_operator, session):
         ti = create_task_instance_of_operator(
             AutoMLDeleteDatasetOperator,
             # Templated fields
@@ -891,6 +907,8 @@ class TestAutoMLDatasetDeleteOperator:
             task_id="test_template_body_templating_task",
             execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         task: AutoMLDeleteDatasetOperator = ti.task
         assert task.dataset_id == "dataset-id"
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index ac314a1ec7..d049986750 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -702,7 +702,7 @@ class TestBigQueryOperator:
             operator.execute(MagicMock())
 
     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
-    def test_bigquery_operator_defaults(self, mock_hook, create_task_instance_of_operator):
+    def test_bigquery_operator_defaults(self, mock_hook, create_task_instance_of_operator, session):
         with pytest.warns(AirflowProviderDeprecationWarning, match=self.deprecation_message):
             ti = create_task_instance_of_operator(
                 BigQueryExecuteQueryOperator,
@@ -711,6 +711,8 @@ class TestBigQueryOperator:
                 sql="Select * from test_table",
                 schema_update_options=None,
             )
+        session.add(ti)
+        session.commit()
         operator = ti.task
 
         operator.execute(MagicMock())
@@ -742,6 +744,7 @@ class TestBigQueryOperator:
         self,
         dag_maker,
         create_task_instance_of_operator,
+        session,
     ):
         with pytest.warns(AirflowProviderDeprecationWarning, match=self.deprecation_message):
             ti = create_task_instance_of_operator(
@@ -751,6 +754,8 @@ class TestBigQueryOperator:
                 task_id=TASK_ID,
                 sql="SELECT * FROM test_table",
             )
+        session.add(ti)
+        session.commit()
         serialized_dag = dag_maker.get_serialized_data()
         deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"])
         assert hasattr(deserialized_dag.tasks[0], "sql")
@@ -840,6 +845,7 @@ class TestBigQueryOperator:
         self,
         mock_hook,
         create_task_instance_of_operator,
+        session,
     ):
         with pytest.warns(AirflowProviderDeprecationWarning, match=self.deprecation_message):
             ti = create_task_instance_of_operator(
@@ -850,7 +856,8 @@ class TestBigQueryOperator:
                 sql="SELECT * FROM test_table",
             )
         bigquery_task = ti.task
-
+        session.add(ti)
+        session.commit()
         ti.xcom_push(key="job_id_path", value=TEST_FULL_JOB_ID)
 
         assert (
@@ -860,7 +867,7 @@ class TestBigQueryOperator:
 
     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
     def test_bigquery_operator_extra_link_when_multiple_query(
-        self, mock_hook, create_task_instance_of_operator
+        self, mock_hook, create_task_instance_of_operator, session
     ):
         with pytest.warns(AirflowProviderDeprecationWarning, match=self.deprecation_message):
             ti = create_task_instance_of_operator(
@@ -871,7 +878,8 @@ class TestBigQueryOperator:
                 sql=["SELECT * FROM test_table", "SELECT * FROM test_table2"],
             )
         bigquery_task = ti.task
-
+        session.add(ti)
+        session.commit()
         ti.xcom_push(key="job_id_path", value=[TEST_FULL_JOB_ID, 
TEST_FULL_JOB_ID_2])
 
         assert {"BigQuery Console #1", "BigQuery Console #2"} == 
bigquery_task.operator_extra_link_dict.keys()
diff --git a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
index 6b6106e47a..73c093a5c0 100644
--- a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
+++ b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
@@ -385,7 +385,7 @@ class TestGcpStorageTransferJobCreateOperator:
     @mock.patch(
         "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
     )
-    def test_templates(self, _, create_task_instance_of_operator, body, excepted):
+    def test_templates(self, _, create_task_instance_of_operator, body, excepted, session):
         dag_id = "TestGcpStorageTransferJobCreateOperator"
         ti = create_task_instance_of_operator(
             CloudDataTransferServiceCreateJobOperator,
@@ -395,6 +395,8 @@ class TestGcpStorageTransferJobCreateOperator:
             aws_conn_id="{{ dag.dag_id }}",
             task_id="task-id",
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert excepted == getattr(ti.task, "body")
         assert dag_id == getattr(ti.task, "gcp_conn_id")
@@ -432,7 +434,7 @@ class TestGcpStorageTransferJobUpdateOperator:
     @mock.patch(
         "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
     )
-    def test_templates(self, _, create_task_instance_of_operator):
+    def test_templates(self, _, create_task_instance_of_operator, session):
         dag_id = "TestGcpStorageTransferJobUpdateOperator_test_templates"
         ti = create_task_instance_of_operator(
             CloudDataTransferServiceUpdateJobOperator,
@@ -441,6 +443,8 @@ class TestGcpStorageTransferJobUpdateOperator:
             body={"transferJob": {"name": "{{ dag.dag_id }}"}},
             task_id="task-id",
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == getattr(ti.task, "body")["transferJob"]["name"]
         assert dag_id == getattr(ti.task, "job_name")
@@ -474,7 +478,7 @@ class TestGcpStorageTransferJobDeleteOperator:
     @mock.patch(
         "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
     )
-    def test_job_delete_with_templates(self, _, create_task_instance_of_operator):
+    def test_job_delete_with_templates(self, _, create_task_instance_of_operator, session):
         dag_id = "test_job_delete_with_templates"
         ti = create_task_instance_of_operator(
             CloudDataTransferServiceDeleteJobOperator,
@@ -484,6 +488,8 @@ class TestGcpStorageTransferJobDeleteOperator:
             api_version="{{ dag.dag_id }}",
             task_id=TASK_ID,
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == ti.task.job_name
         assert dag_id == ti.task.gcp_conn_id
@@ -524,7 +530,7 @@ class TestGcpStorageTransferJobRunOperator:
     @mock.patch(
         "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
     )
-    def test_job_run_with_templates(self, _, create_task_instance_of_operator):
+    def test_job_run_with_templates(self, _, create_task_instance_of_operator, session):
         dag_id = "test_job_run_with_templates"
         ti = create_task_instance_of_operator(
             CloudDataTransferServiceRunJobOperator,
@@ -536,6 +542,8 @@ class TestGcpStorageTransferJobRunOperator:
             google_impersonation_chain="{{ dag.dag_id }}",
             task_id=TASK_ID,
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == ti.task.job_name
         assert dag_id == ti.task.project_id
@@ -576,7 +584,7 @@ class TestGpcStorageTransferOperationsGetOperator:
     @mock.patch(
         "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
     )
-    def test_operation_get_with_templates(self, _, create_task_instance_of_operator):
+    def test_operation_get_with_templates(self, _, create_task_instance_of_operator, session):
         dag_id = "test_operation_get_with_templates"
         ti = create_task_instance_of_operator(
             CloudDataTransferServiceGetOperationOperator,
@@ -584,6 +592,8 @@ class TestGpcStorageTransferOperationsGetOperator:
             operation_name="{{ dag.dag_id }}",
             task_id="task-id",
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == ti.task.operation_name
 
@@ -621,7 +631,7 @@ class TestGcpStorageTransferOperationListOperator:
     @mock.patch(
         "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
     )
-    def test_templates(self, _, create_task_instance_of_operator):
+    def test_templates(self, _, create_task_instance_of_operator, session):
         dag_id = "TestGcpStorageTransferOperationListOperator_test_templates"
         ti = create_task_instance_of_operator(
             CloudDataTransferServiceListOperationsOperator,
@@ -630,6 +640,8 @@ class TestGcpStorageTransferOperationListOperator:
             gcp_conn_id="{{ dag.dag_id }}",
             task_id="task-id",
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
 
         assert dag_id == ti.task.request_filter["job_names"][0]
@@ -661,7 +673,7 @@ class TestGcpStorageTransferOperationsPauseOperator:
     @mock.patch(
         "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
     )
-    def test_operation_pause_with_templates(self, _, create_task_instance_of_operator):
+    def test_operation_pause_with_templates(self, _, create_task_instance_of_operator, session):
         dag_id = "test_operation_pause_with_templates"
         ti = create_task_instance_of_operator(
             CloudDataTransferServicePauseOperationOperator,
@@ -671,6 +683,8 @@ class TestGcpStorageTransferOperationsPauseOperator:
             api_version="{{ dag.dag_id }}",
             task_id=TASK_ID,
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == ti.task.operation_name
         assert dag_id == ti.task.gcp_conn_id
@@ -711,7 +725,7 @@ class TestGcpStorageTransferOperationsResumeOperator:
     @mock.patch(
         "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
     )
-    def test_operation_resume_with_templates(self, _, create_task_instance_of_operator):
+    def test_operation_resume_with_templates(self, _, create_task_instance_of_operator, session):
         dag_id = "test_operation_resume_with_templates"
         ti = create_task_instance_of_operator(
             CloudDataTransferServiceResumeOperationOperator,
@@ -721,6 +735,8 @@ class TestGcpStorageTransferOperationsResumeOperator:
             api_version="{{ dag.dag_id }}",
             task_id=TASK_ID,
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == ti.task.operation_name
         assert dag_id == ti.task.gcp_conn_id
@@ -764,7 +780,7 @@ class TestGcpStorageTransferOperationsCancelOperator:
     @mock.patch(
         "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
     )
-    def test_operation_cancel_with_templates(self, _, create_task_instance_of_operator):
+    def test_operation_cancel_with_templates(self, _, create_task_instance_of_operator, session):
         dag_id = "test_operation_cancel_with_templates"
         ti = create_task_instance_of_operator(
             CloudDataTransferServiceCancelOperationOperator,
@@ -774,6 +790,8 @@ class TestGcpStorageTransferOperationsCancelOperator:
             api_version="{{ dag.dag_id }}",
             task_id=TASK_ID,
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == ti.task.operation_name
         assert dag_id == ti.task.gcp_conn_id
@@ -814,7 +832,7 @@ class TestS3ToGoogleCloudStorageTransferOperator:
     @mock.patch(
         "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
     )
-    def test_templates(self, _, create_task_instance_of_operator):
+    def test_templates(self, _, create_task_instance_of_operator, session):
         dag_id = "TestS3ToGoogleCloudStorageTransferOperator_test_templates"
         ti = create_task_instance_of_operator(
             CloudDataTransferServiceS3ToGCSOperator,
@@ -826,6 +844,8 @@ class TestS3ToGoogleCloudStorageTransferOperator:
             gcp_conn_id="{{ dag.dag_id }}",
             task_id=TASK_ID,
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == ti.task.s3_bucket
         assert dag_id == ti.task.gcs_bucket
@@ -962,7 +982,7 @@ class TestGoogleCloudStorageToGoogleCloudStorageTransferOperator:
     @mock.patch(
         "airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
     )
-    def test_templates(self, _, create_task_instance_of_operator):
+    def test_templates(self, _, create_task_instance_of_operator, session):
         dag_id = "TestGoogleCloudStorageToGoogleCloudStorageTransferOperator_test_templates"
         ti = create_task_instance_of_operator(
             CloudDataTransferServiceGCSToGCSOperator,
@@ -974,6 +994,8 @@ class TestGoogleCloudStorageToGoogleCloudStorageTransferOperator:
             gcp_conn_id="{{ dag.dag_id }}",
             task_id=TASK_ID,
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == ti.task.source_bucket
         assert dag_id == ti.task.destination_bucket
diff --git a/tests/providers/google/cloud/operators/test_compute.py b/tests/providers/google/cloud/operators/test_compute.py
index d1dbecbd58..fac74bae48 100644
--- a/tests/providers/google/cloud/operators/test_compute.py
+++ b/tests/providers/google/cloud/operators/test_compute.py
@@ -488,7 +488,7 @@ class TestGceInstanceStart:
     # (could be anything else) just to test if the templating works for all fields
     @pytest.mark.db_test
     @mock.patch(COMPUTE_ENGINE_HOOK_PATH)
-    def test_start_instance_with_templates(self, _, create_task_instance_of_operator):
+    def test_start_instance_with_templates(self, _, create_task_instance_of_operator, session):
         dag_id = "test_instance_start_with_templates"
         ti = create_task_instance_of_operator(
             ComputeEngineStartInstanceOperator,
@@ -500,6 +500,8 @@ class TestGceInstanceStart:
             api_version="{{ dag.dag_id }}",
             task_id="id",
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == ti.task.project_id
         assert dag_id == ti.task.zone
@@ -579,7 +581,7 @@ class TestGceInstanceStop:
     # (could be anything else) just to test if the templating works for all fields
     @pytest.mark.db_test
     @mock.patch(COMPUTE_ENGINE_HOOK_PATH)
-    def test_instance_stop_with_templates(self, _, create_task_instance_of_operator):
+    def test_instance_stop_with_templates(self, _, create_task_instance_of_operator, session):
         dag_id = "test_instance_stop_with_templates"
         ti = create_task_instance_of_operator(
             ComputeEngineStopInstanceOperator,
@@ -591,6 +593,8 @@ class TestGceInstanceStop:
             api_version="{{ dag.dag_id }}",
             task_id="id",
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == ti.task.project_id
         assert dag_id == ti.task.zone
@@ -660,7 +664,7 @@ class TestGceInstanceSetMachineType:
     # (could be anything else) just to test if the templating works for all fields
     @pytest.mark.db_test
     @mock.patch(COMPUTE_ENGINE_HOOK_PATH)
-    def test_machine_type_set_with_templates(self, _, create_task_instance_of_operator):
+    def test_machine_type_set_with_templates(self, _, create_task_instance_of_operator, session):
         dag_id = "test_set_machine_type_with_templates"
         ti = create_task_instance_of_operator(
             ComputeEngineSetMachineTypeOperator,
@@ -673,6 +677,8 @@ class TestGceInstanceSetMachineType:
             api_version="{{ dag.dag_id }}",
             task_id="id",
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == ti.task.project_id
         assert dag_id == ti.task.zone
diff --git a/tests/providers/google/cloud/operators/test_dataprep.py b/tests/providers/google/cloud/operators/test_dataprep.py
index acbbf2063b..2377184f6a 100644
--- a/tests/providers/google/cloud/operators/test_dataprep.py
+++ b/tests/providers/google/cloud/operators/test_dataprep.py
@@ -174,7 +174,7 @@ class TestDataprepCopyFlowOperatorTest:
 
     @pytest.mark.db_test
     @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
-    def test_execute_with_templated_params(self, _, create_task_instance_of_operator):
+    def test_execute_with_templated_params(self, _, create_task_instance_of_operator, session):
         dag_id = "test_execute_with_templated_params"
         ti = create_task_instance_of_operator(
             DataprepCopyFlowOperator,
@@ -185,6 +185,8 @@ class TestDataprepCopyFlowOperatorTest:
             name="{{ dag.dag_id }}",
             description="{{ dag.dag_id }}",
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == ti.task.project_id
         assert dag_id == ti.task.flow_id
@@ -248,7 +250,7 @@ class TestDataprepDeleteFlowOperator:
 
     @pytest.mark.db_test
     @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
-    def test_execute_with_template_params(self, _, create_task_instance_of_operator):
+    def test_execute_with_template_params(self, _, create_task_instance_of_operator, session):
         dag_id = "test_execute_delete_flow_with_template"
         ti = create_task_instance_of_operator(
             DataprepDeleteFlowOperator,
@@ -256,6 +258,8 @@ class TestDataprepDeleteFlowOperator:
             task_id=TASK_ID,
             flow_id="{{ dag.dag_id }}",
         )
+        session.add(ti)
+        session.commit()
         ti.render_templates()
         assert dag_id == ti.task.flow_id
 
@@ -279,7 +283,7 @@ class TestDataprepRunFlowOperator:
 
     @pytest.mark.db_test
     @mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
-    def test_execute_with_template_params(self, _, create_task_instance_of_operator):
+    def test_execute_with_template_params(self, _, create_task_instance_of_operator, session):
         dag_id = "test_execute_run_flow_with_template"
         ti = create_task_instance_of_operator(
             DataprepRunFlowOperator,
@@ -289,7 +293,8 @@ class TestDataprepRunFlowOperator:
             flow_id="{{ dag.dag_id }}",
             body_request={},
         )
-
+        session.add(ti)
+        session.commit()
         ti.render_templates()
 
         assert dag_id == ti.task.project_id
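
Two companion patterns appear in the diff above. Tests that mock the requests library outright (the Airbyte and Druid hook tests) cannot be made to work in database isolation mode at all, so they are skipped wholesale via the skip_if_database_isolation_mode marker; a module-level sketch mirroring the airbyte change:

    import pytest

    # Skip the whole module in database isolation mode: these tests mock
    # HTTP requests directly, which the isolation layer cannot intercept.
    pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]

The tests/conftest.py change supports the same goal from the teardown side: since the tests now commit TaskInstance and DagRun rows, _delete_log_template() deletes those rows first, presumably so the log template row can be removed cleanly without newly persisted instances still referencing it.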

