This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 57d091ac58 Fix all provider tests for database isolation mode (#41242)
57d091ac58 is described below
commit 57d091ac58eea20bf507f4487c18634f23bcd914
Author: Jarek Potiuk <[email protected]>
AuthorDate: Sat Aug 3 21:15:30 2024 +0200
Fix all provider tests for database isolation mode (#41242)
---
tests/conftest.py | 4 ++
tests/providers/airbyte/hooks/test_airbyte.py | 3 +-
.../alibaba/cloud/log/test_oss_task_handler.py | 4 +-
.../amazon/aws/log/test_cloudwatch_task_handler.py | 14 +++----
.../amazon/aws/log/test_s3_task_handler.py | 17 ++++-----
.../providers/amazon/aws/operators/test_appflow.py | 4 +-
tests/providers/apache/druid/hooks/test_druid.py | 3 ++
.../providers/apache/druid/operators/test_druid.py | 6 ++-
.../databricks/operators/test_databricks_copy.py | 2 +-
.../providers/google/cloud/links/test_translate.py | 20 +++++++---
.../google/cloud/log/test_gcs_task_handler.py | 8 +++-
.../google/cloud/operators/test_automl.py | 36 +++++++++++++-----
.../google/cloud/operators/test_bigquery.py | 16 ++++++--
.../test_cloud_storage_transfer_service.py | 44 ++++++++++++++++------
.../google/cloud/operators/test_compute.py | 12 ++++--
.../google/cloud/operators/test_dataprep.py | 13 +++++--
16 files changed, 145 insertions(+), 61 deletions(-)
diff --git a/tests/conftest.py b/tests/conftest.py
index a3d8ed1d3e..ebb25ea6a8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1192,6 +1192,10 @@ def create_log_template(request):
session.commit()
def _delete_log_template():
+ from airflow.models import DagRun, TaskInstance
+
+ session.query(TaskInstance).delete()
+ session.query(DagRun).delete()
session.delete(log_template)
session.commit()
diff --git a/tests/providers/airbyte/hooks/test_airbyte.py
b/tests/providers/airbyte/hooks/test_airbyte.py
index 6cf211909e..e91227a3a0 100644
--- a/tests/providers/airbyte/hooks/test_airbyte.py
+++ b/tests/providers/airbyte/hooks/test_airbyte.py
@@ -26,7 +26,8 @@ from airflow.models import Connection
from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
from airflow.utils import db
-pytestmark = pytest.mark.db_test
+# those tests will not work with database isolation because they mock requests
+pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]
class TestAirbyteHook:
diff --git a/tests/providers/alibaba/cloud/log/test_oss_task_handler.py
b/tests/providers/alibaba/cloud/log/test_oss_task_handler.py
index 5bcad1a7a4..2c00c4d022 100644
--- a/tests/providers/alibaba/cloud/log/test_oss_task_handler.py
+++ b/tests/providers/alibaba/cloud/log/test_oss_task_handler.py
@@ -47,7 +47,7 @@ class TestOSSTaskHandler:
self.oss_task_handler = OSSTaskHandler(self.base_log_folder,
self.oss_log_folder)
@pytest.fixture(autouse=True)
- def task_instance(self, create_task_instance):
+ def task_instance(self, create_task_instance, dag_maker):
self.ti = ti = create_task_instance(
dag_id="dag_for_testing_oss_task_handler",
task_id="task_for_testing_oss_task_handler",
@@ -56,6 +56,8 @@ class TestOSSTaskHandler:
)
ti.try_number = 1
ti.raw = False
+ dag_maker.session.merge(ti)
+ dag_maker.session.commit()
yield
clear_db_runs()
clear_db_dags()
diff --git a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
index 4f9ecb1664..36a51bbb74 100644
--- a/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_cloudwatch_task_handler.py
@@ -33,7 +33,6 @@ from airflow.operators.empty import EmptyOperator
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.log.cloudwatch_task_handler import
CloudwatchTaskHandler
from airflow.providers.amazon.aws.utils import datetime_to_epoch_utc_ms
-from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timezone import datetime
from tests.test_utils.config import conf_vars
@@ -54,7 +53,7 @@ def logmock():
class TestCloudwatchTaskHandler:
@conf_vars({("logging", "remote_log_conn_id"): "aws_default"})
@pytest.fixture(autouse=True)
- def setup_tests(self, create_log_template, tmp_path_factory):
+ def setup_tests(self, create_log_template, tmp_path_factory, session):
self.remote_log_group = "log_group_name"
self.region_name = "us-west-2"
self.local_log_location =
str(tmp_path_factory.mktemp("local-cloudwatch-log-location"))
@@ -70,15 +69,16 @@ class TestCloudwatchTaskHandler:
self.dag = DAG(dag_id=dag_id, start_date=date)
task = EmptyOperator(task_id=task_id, dag=self.dag)
dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=date,
run_id="test", run_type="scheduled")
- with create_session() as session:
- session.add(dag_run)
- session.commit()
- session.refresh(dag_run)
+ session.add(dag_run)
+ session.commit()
+ session.refresh(dag_run)
self.ti = TaskInstance(task=task, run_id=dag_run.run_id)
self.ti.dag_run = dag_run
self.ti.try_number = 1
self.ti.state = State.RUNNING
+ session.add(self.ti)
+ session.commit()
self.remote_log_stream =
(f"{dag_id}/{task_id}/{date.isoformat()}/{self.ti.try_number}.log").replace(
":", "_"
@@ -88,8 +88,6 @@ class TestCloudwatchTaskHandler:
yield
self.cloudwatch_task_handler.handler = None
- with create_session() as session:
- session.query(DagRun).delete()
def test_hook(self):
assert isinstance(self.cloudwatch_task_handler.hook, AwsLogsHook)
diff --git a/tests/providers/amazon/aws/log/test_s3_task_handler.py
b/tests/providers/amazon/aws/log/test_s3_task_handler.py
index dd7a81904c..7bec287105 100644
--- a/tests/providers/amazon/aws/log/test_s3_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_s3_task_handler.py
@@ -31,7 +31,6 @@ from airflow.models import DAG, DagRun, TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.log.s3_task_handler import S3TaskHandler
-from airflow.utils.session import create_session
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.timezone import datetime
from tests.test_utils.config import conf_vars
@@ -47,7 +46,7 @@ def s3mock():
class TestS3TaskHandler:
@conf_vars({("logging", "remote_log_conn_id"): "aws_default"})
@pytest.fixture(autouse=True)
- def setup_tests(self, create_log_template, tmp_path_factory):
+ def setup_tests(self, create_log_template, tmp_path_factory, session):
self.remote_log_base = "s3://bucket/remote/log/location"
self.remote_log_location = "s3://bucket/remote/log/location/1.log"
self.remote_log_key = "remote/log/location/1.log"
@@ -61,26 +60,24 @@ class TestS3TaskHandler:
self.dag = DAG("dag_for_testing_s3_task_handler", start_date=date)
task = EmptyOperator(task_id="task_for_testing_s3_log_handler",
dag=self.dag)
dag_run = DagRun(dag_id=self.dag.dag_id, execution_date=date,
run_id="test", run_type="manual")
- with create_session() as session:
- session.add(dag_run)
- session.commit()
- session.refresh(dag_run)
+ session.add(dag_run)
+ session.commit()
+ session.refresh(dag_run)
self.ti = TaskInstance(task=task, run_id=dag_run.run_id)
self.ti.dag_run = dag_run
self.ti.try_number = 1
self.ti.state = State.RUNNING
+ session.add(self.ti)
+ session.commit()
self.conn = boto3.client("s3")
self.conn.create_bucket(Bucket="bucket")
-
yield
self.dag.clear()
- with create_session() as session:
- session.query(DagRun).delete()
-
+ session.query(DagRun).delete()
if self.s3_task_handler.handler:
with contextlib.suppress(Exception):
os.remove(self.s3_task_handler.handler.baseFilename)
diff --git a/tests/providers/amazon/aws/operators/test_appflow.py
b/tests/providers/amazon/aws/operators/test_appflow.py
index c49a0ca69c..58f79f6ea3 100644
--- a/tests/providers/amazon/aws/operators/test_appflow.py
+++ b/tests/providers/amazon/aws/operators/test_appflow.py
@@ -54,12 +54,14 @@ AppflowBaseOperator.UPDATE_PROPAGATION_TIME = 0 # avoid
wait
@pytest.mark.db_test
@pytest.fixture
-def ctx(create_task_instance):
+def ctx(create_task_instance, session):
ti = create_task_instance(
dag_id=DAG_ID,
task_id=TASK_ID,
schedule="0 12 * * *",
)
+ session.add(ti)
+ session.commit()
return {"task_instance": ti}
diff --git a/tests/providers/apache/druid/hooks/test_druid.py
b/tests/providers/apache/druid/hooks/test_druid.py
index 0d42695f06..9befbf37f0 100644
--- a/tests/providers/apache/druid/hooks/test_druid.py
+++ b/tests/providers/apache/druid/hooks/test_druid.py
@@ -26,6 +26,9 @@ from airflow.exceptions import AirflowException
from airflow.providers.apache.druid.hooks.druid import DruidDbApiHook,
DruidHook, IngestionType
+# This test mocks the requests library to avoid making actual HTTP requests so
database isolation mode
+# will not work for it
+@pytest.mark.skip_if_database_isolation_mode
@pytest.mark.db_test
class TestDruidSubmitHook:
def setup_method(self):
diff --git a/tests/providers/apache/druid/operators/test_druid.py
b/tests/providers/apache/druid/operators/test_druid.py
index 286cdd3916..9aa85235f7 100644
--- a/tests/providers/apache/druid/operators/test_druid.py
+++ b/tests/providers/apache/druid/operators/test_druid.py
@@ -50,6 +50,7 @@ RENDERED_INDEX = {
}
+@pytest.mark.need_serialized_dag
@pytest.mark.db_test
def test_render_template(dag_maker):
with dag_maker("test_druid_render_template", default_args={"start_date":
DEFAULT_DATE}):
@@ -59,7 +60,10 @@ def test_render_template(dag_maker):
params={"index_type": "index_hadoop", "datasource":
"datasource_prd"},
)
-
dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED).task_instances[0].render_templates()
+ dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+ dag_maker.session.add(dag_run.task_instances[0])
+ dag_maker.session.commit()
+ dag_run.task_instances[0].render_templates()
assert RENDERED_INDEX == json.loads(operator.json_index_file)
diff --git a/tests/providers/databricks/operators/test_databricks_copy.py
b/tests/providers/databricks/operators/test_databricks_copy.py
index a481841868..37970efe35 100644
--- a/tests/providers/databricks/operators/test_databricks_copy.py
+++ b/tests/providers/databricks/operators/test_databricks_copy.py
@@ -232,7 +232,7 @@ def test_incorrect_params_wrong_format():
@pytest.mark.db_test
-def test_templating(create_task_instance_of_operator):
+def test_templating(create_task_instance_of_operator, session):
ti = create_task_instance_of_operator(
DatabricksCopyIntoOperator,
# Templated fields
diff --git a/tests/providers/google/cloud/links/test_translate.py
b/tests/providers/google/cloud/links/test_translate.py
index c0e244acff..82547907a8 100644
--- a/tests/providers/google/cloud/links/test_translate.py
+++ b/tests/providers/google/cloud/links/test_translate.py
@@ -47,7 +47,7 @@ MODEL = "test-model"
class TestTranslationLegacyDatasetLink:
@pytest.mark.db_test
- def test_get_link(self, create_task_instance_of_operator):
+ def test_get_link(self, create_task_instance_of_operator, session):
expected_url =
f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/sentences?project={GCP_PROJECT_ID}"
link = TranslationLegacyDatasetLink()
ti = create_task_instance_of_operator(
@@ -57,6 +57,8 @@ class TestTranslationLegacyDatasetLink:
dataset=DATASET,
location=GCP_LOCATION,
)
+ session.add(ti)
+ session.commit()
link.persist(context={"ti": ti}, task_instance=ti.task,
dataset_id=DATASET, project_id=GCP_PROJECT_ID)
actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
assert actual_url == expected_url
@@ -64,7 +66,7 @@ class TestTranslationLegacyDatasetLink:
class TestTranslationDatasetListLink:
@pytest.mark.db_test
- def test_get_link(self, create_task_instance_of_operator):
+ def test_get_link(self, create_task_instance_of_operator, session):
expected_url =
f"{TRANSLATION_BASE_LINK}/datasets?project={GCP_PROJECT_ID}"
link = TranslationDatasetListLink()
ti = create_task_instance_of_operator(
@@ -73,6 +75,8 @@ class TestTranslationDatasetListLink:
task_id="test_dataset_list_link_task",
location=GCP_LOCATION,
)
+ session.add(ti)
+ session.commit()
link.persist(context={"ti": ti}, task_instance=ti.task,
project_id=GCP_PROJECT_ID)
actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
assert actual_url == expected_url
@@ -80,7 +84,7 @@ class TestTranslationDatasetListLink:
class TestTranslationLegacyModelLink:
@pytest.mark.db_test
- def test_get_link(self, create_task_instance_of_operator):
+ def test_get_link(self, create_task_instance_of_operator, session):
expected_url = (
f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
f"evaluate;modelId={MODEL}?project={GCP_PROJECT_ID}"
@@ -94,6 +98,8 @@ class TestTranslationLegacyModelLink:
project_id=GCP_PROJECT_ID,
location=GCP_LOCATION,
)
+ session.add(ti)
+ session.commit()
link.persist(
context={"ti": ti},
task_instance=ti.task,
@@ -107,7 +113,7 @@ class TestTranslationLegacyModelLink:
class TestTranslationLegacyModelTrainLink:
@pytest.mark.db_test
- def test_get_link(self, create_task_instance_of_operator):
+ def test_get_link(self, create_task_instance_of_operator, session):
expected_url = (
f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
f"train?project={GCP_PROJECT_ID}"
@@ -121,6 +127,8 @@ class TestTranslationLegacyModelTrainLink:
project_id=GCP_PROJECT_ID,
location=GCP_LOCATION,
)
+ session.add(ti)
+ session.commit()
link.persist(
context={"ti": ti},
task_instance=ti.task,
@@ -132,7 +140,7 @@ class TestTranslationLegacyModelTrainLink:
class TestTranslationLegacyModelPredictLink:
@pytest.mark.db_test
- def test_get_link(self, create_task_instance_of_operator):
+ def test_get_link(self, create_task_instance_of_operator, session):
expected_url = (
f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
f"predict;modelId={MODEL}?project={GCP_PROJECT_ID}"
@@ -149,6 +157,8 @@ class TestTranslationLegacyModelPredictLink:
output_config="input_config",
)
ti.task.model = Model(dataset_id=DATASET, display_name=MODEL)
+ session.add(ti)
+ session.commit()
link.persist(context={"ti": ti}, task_instance=ti.task,
model_id=MODEL, project_id=GCP_PROJECT_ID)
actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
assert actual_url == expected_url
diff --git a/tests/providers/google/cloud/log/test_gcs_task_handler.py
b/tests/providers/google/cloud/log/test_gcs_task_handler.py
index 1344b2c797..a860e52e15 100644
--- a/tests/providers/google/cloud/log/test_gcs_task_handler.py
+++ b/tests/providers/google/cloud/log/test_gcs_task_handler.py
@@ -34,7 +34,7 @@ from tests.test_utils.db import clear_db_dags, clear_db_runs
@pytest.mark.db_test
class TestGCSTaskHandler:
@pytest.fixture(autouse=True)
- def task_instance(self, create_task_instance):
+ def task_instance(self, create_task_instance, session):
self.ti = ti = create_task_instance(
dag_id="dag_for_testing_gcs_task_handler",
task_id="task_for_testing_gcs_task_handler",
@@ -43,6 +43,8 @@ class TestGCSTaskHandler:
)
ti.try_number = 1
ti.raw = False
+ session.add(ti)
+ session.commit()
yield
clear_db_runs()
clear_db_dags()
@@ -91,13 +93,15 @@ class TestGCSTaskHandler:
)
@mock.patch("google.cloud.storage.Client")
@mock.patch("google.cloud.storage.Blob")
- def test_should_read_logs_from_remote(self, mock_blob, mock_client,
mock_creds):
+ def test_should_read_logs_from_remote(self, mock_blob, mock_client,
mock_creds, session):
mock_obj = MagicMock()
mock_obj.name = "remote/log/location/1.log"
mock_client.return_value.list_blobs.return_value = [mock_obj]
mock_blob.from_string.return_value.download_as_bytes.return_value =
b"CONTENT"
ti = copy.copy(self.ti)
ti.state = TaskInstanceState.SUCCESS
+ session.add(ti)
+ session.commit()
logs, metadata = self.gcs_task_handler._read(ti, self.ti.try_number)
mock_blob.from_string.assert_called_once_with(
"gs://bucket/remote/log/location/1.log", mock_client.return_value
diff --git a/tests/providers/google/cloud/operators/test_automl.py
b/tests/providers/google/cloud/operators/test_automl.py
index 985c5ef0aa..fbe1753753 100644
--- a/tests/providers/google/cloud/operators/test_automl.py
+++ b/tests/providers/google/cloud/operators/test_automl.py
@@ -119,7 +119,7 @@ class TestAutoMLTrainModelOperator:
mock_hook.assert_not_called()
@pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator):
+ def test_templating(self, create_task_instance_of_operator, session):
ti = create_task_instance_of_operator(
AutoMLTrainModelOperator,
# Templated fields
@@ -131,6 +131,8 @@ class TestAutoMLTrainModelOperator:
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
task: AutoMLTrainModelOperator = ti.task
assert task.model == "model"
@@ -207,7 +209,7 @@ class TestAutoMLBatchPredictOperator:
mock_hook.return_value.batch_predict.assert_not_called()
@pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator):
+ def test_templating(self, create_task_instance_of_operator, session):
ti = create_task_instance_of_operator(
AutoMLBatchPredictOperator,
# Templated fields
@@ -222,6 +224,8 @@ class TestAutoMLBatchPredictOperator:
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
task: AutoMLBatchPredictOperator = ti.task
assert task.model_id == "model"
@@ -265,7 +269,7 @@ class TestAutoMLPredictOperator:
)
@pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator):
+ def test_templating(self, create_task_instance_of_operator, session):
ti = create_task_instance_of_operator(
AutoMLPredictOperator,
# Templated fields
@@ -279,6 +283,8 @@ class TestAutoMLPredictOperator:
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
payload={},
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
task: AutoMLPredictOperator = ti.task
assert task.model_id == "model-id"
@@ -372,7 +378,7 @@ class TestAutoMLCreateImportOperator:
mock_hook.return_value.create_dataset.assert_not_called()
@pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator):
+ def test_templating(self, create_task_instance_of_operator, session):
ti = create_task_instance_of_operator(
AutoMLCreateDatasetOperator,
# Templated fields
@@ -385,6 +391,8 @@ class TestAutoMLCreateImportOperator:
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
task: AutoMLCreateDatasetOperator = ti.task
assert task.dataset == "dataset"
@@ -530,7 +538,7 @@ class TestAutoMLGetModelOperator:
)
@pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator):
+ def test_templating(self, create_task_instance_of_operator, session):
ti = create_task_instance_of_operator(
AutoMLGetModelOperator,
# Templated fields
@@ -543,6 +551,8 @@ class TestAutoMLGetModelOperator:
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
task: AutoMLGetModelOperator = ti.task
assert task.model_id == "model-id"
@@ -599,7 +609,7 @@ class TestAutoMLDeleteModelOperator:
mock_hook.return_value.delete_model.assert_not_called()
@pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator):
+ def test_templating(self, create_task_instance_of_operator, session):
ti = create_task_instance_of_operator(
AutoMLDeleteModelOperator,
# Templated fields
@@ -612,6 +622,8 @@ class TestAutoMLDeleteModelOperator:
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
task: AutoMLDeleteModelOperator = ti.task
assert task.model_id == "model-id"
@@ -712,7 +724,7 @@ class TestAutoMLDatasetImportOperator:
mock_hook.return_value.import_data.assert_not_called()
@pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator):
+ def test_templating(self, create_task_instance_of_operator, session):
ti = create_task_instance_of_operator(
AutoMLImportDataOperator,
# Templated fields
@@ -726,6 +738,8 @@ class TestAutoMLDatasetImportOperator:
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
task: AutoMLImportDataOperator = ti.task
assert task.dataset_id == "dataset-id"
@@ -811,7 +825,7 @@ class TestAutoMLDatasetListOperator:
)
@pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator):
+ def test_templating(self, create_task_instance_of_operator, session):
ti = create_task_instance_of_operator(
AutoMLListDatasetOperator,
# Templated fields
@@ -823,6 +837,8 @@ class TestAutoMLDatasetListOperator:
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
task: AutoMLListDatasetOperator = ti.task
assert task.location == "location"
@@ -878,7 +894,7 @@ class TestAutoMLDatasetDeleteOperator:
mock_hook.return_value.delete_dataset.assert_not_called()
@pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator):
+ def test_templating(self, create_task_instance_of_operator, session):
ti = create_task_instance_of_operator(
AutoMLDeleteDatasetOperator,
# Templated fields
@@ -891,6 +907,8 @@ class TestAutoMLDatasetDeleteOperator:
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
task: AutoMLDeleteDatasetOperator = ti.task
assert task.dataset_id == "dataset-id"
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py
b/tests/providers/google/cloud/operators/test_bigquery.py
index ac314a1ec7..d049986750 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -702,7 +702,7 @@ class TestBigQueryOperator:
operator.execute(MagicMock())
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
- def test_bigquery_operator_defaults(self, mock_hook,
create_task_instance_of_operator):
+ def test_bigquery_operator_defaults(self, mock_hook,
create_task_instance_of_operator, session):
with pytest.warns(AirflowProviderDeprecationWarning,
match=self.deprecation_message):
ti = create_task_instance_of_operator(
BigQueryExecuteQueryOperator,
@@ -711,6 +711,8 @@ class TestBigQueryOperator:
sql="Select * from test_table",
schema_update_options=None,
)
+ session.add(ti)
+ session.commit()
operator = ti.task
operator.execute(MagicMock())
@@ -742,6 +744,7 @@ class TestBigQueryOperator:
self,
dag_maker,
create_task_instance_of_operator,
+ session,
):
with pytest.warns(AirflowProviderDeprecationWarning,
match=self.deprecation_message):
ti = create_task_instance_of_operator(
@@ -751,6 +754,8 @@ class TestBigQueryOperator:
task_id=TASK_ID,
sql="SELECT * FROM test_table",
)
+ session.add(ti)
+ session.commit()
serialized_dag = dag_maker.get_serialized_data()
deserialized_dag = SerializedDAG.deserialize_dag(serialized_dag["dag"])
assert hasattr(deserialized_dag.tasks[0], "sql")
@@ -840,6 +845,7 @@ class TestBigQueryOperator:
self,
mock_hook,
create_task_instance_of_operator,
+ session,
):
with pytest.warns(AirflowProviderDeprecationWarning,
match=self.deprecation_message):
ti = create_task_instance_of_operator(
@@ -850,7 +856,8 @@ class TestBigQueryOperator:
sql="SELECT * FROM test_table",
)
bigquery_task = ti.task
-
+ session.add(ti)
+ session.commit()
ti.xcom_push(key="job_id_path", value=TEST_FULL_JOB_ID)
assert (
@@ -860,7 +867,7 @@ class TestBigQueryOperator:
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_operator_extra_link_when_multiple_query(
- self, mock_hook, create_task_instance_of_operator
+ self, mock_hook, create_task_instance_of_operator, session
):
with pytest.warns(AirflowProviderDeprecationWarning,
match=self.deprecation_message):
ti = create_task_instance_of_operator(
@@ -871,7 +878,8 @@ class TestBigQueryOperator:
sql=["SELECT * FROM test_table", "SELECT * FROM test_table2"],
)
bigquery_task = ti.task
-
+ session.add(ti)
+ session.commit()
ti.xcom_push(key="job_id_path", value=[TEST_FULL_JOB_ID,
TEST_FULL_JOB_ID_2])
assert {"BigQuery Console #1", "BigQuery Console #2"} ==
bigquery_task.operator_extra_link_dict.keys()
diff --git
a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
index 6b6106e47a..73c093a5c0 100644
---
a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
+++
b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
@@ -385,7 +385,7 @@ class TestGcpStorageTransferJobCreateOperator:
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
- def test_templates(self, _, create_task_instance_of_operator, body,
excepted):
+ def test_templates(self, _, create_task_instance_of_operator, body,
excepted, session):
dag_id = "TestGcpStorageTransferJobCreateOperator"
ti = create_task_instance_of_operator(
CloudDataTransferServiceCreateJobOperator,
@@ -395,6 +395,8 @@ class TestGcpStorageTransferJobCreateOperator:
aws_conn_id="{{ dag.dag_id }}",
task_id="task-id",
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert excepted == getattr(ti.task, "body")
assert dag_id == getattr(ti.task, "gcp_conn_id")
@@ -432,7 +434,7 @@ class TestGcpStorageTransferJobUpdateOperator:
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
- def test_templates(self, _, create_task_instance_of_operator):
+ def test_templates(self, _, create_task_instance_of_operator, session):
dag_id = "TestGcpStorageTransferJobUpdateOperator_test_templates"
ti = create_task_instance_of_operator(
CloudDataTransferServiceUpdateJobOperator,
@@ -441,6 +443,8 @@ class TestGcpStorageTransferJobUpdateOperator:
body={"transferJob": {"name": "{{ dag.dag_id }}"}},
task_id="task-id",
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == getattr(ti.task, "body")["transferJob"]["name"]
assert dag_id == getattr(ti.task, "job_name")
@@ -474,7 +478,7 @@ class TestGcpStorageTransferJobDeleteOperator:
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
- def test_job_delete_with_templates(self, _,
create_task_instance_of_operator):
+ def test_job_delete_with_templates(self, _,
create_task_instance_of_operator, session):
dag_id = "test_job_delete_with_templates"
ti = create_task_instance_of_operator(
CloudDataTransferServiceDeleteJobOperator,
@@ -484,6 +488,8 @@ class TestGcpStorageTransferJobDeleteOperator:
api_version="{{ dag.dag_id }}",
task_id=TASK_ID,
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.job_name
assert dag_id == ti.task.gcp_conn_id
@@ -524,7 +530,7 @@ class TestGcpStorageTransferJobRunOperator:
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
- def test_job_run_with_templates(self, _, create_task_instance_of_operator):
+ def test_job_run_with_templates(self, _, create_task_instance_of_operator,
session):
dag_id = "test_job_run_with_templates"
ti = create_task_instance_of_operator(
CloudDataTransferServiceRunJobOperator,
@@ -536,6 +542,8 @@ class TestGcpStorageTransferJobRunOperator:
google_impersonation_chain="{{ dag.dag_id }}",
task_id=TASK_ID,
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.job_name
assert dag_id == ti.task.project_id
@@ -576,7 +584,7 @@ class TestGpcStorageTransferOperationsGetOperator:
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
- def test_operation_get_with_templates(self, _,
create_task_instance_of_operator):
+ def test_operation_get_with_templates(self, _,
create_task_instance_of_operator, session):
dag_id = "test_operation_get_with_templates"
ti = create_task_instance_of_operator(
CloudDataTransferServiceGetOperationOperator,
@@ -584,6 +592,8 @@ class TestGpcStorageTransferOperationsGetOperator:
operation_name="{{ dag.dag_id }}",
task_id="task-id",
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.operation_name
@@ -621,7 +631,7 @@ class TestGcpStorageTransferOperationListOperator:
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
- def test_templates(self, _, create_task_instance_of_operator):
+ def test_templates(self, _, create_task_instance_of_operator, session):
dag_id = "TestGcpStorageTransferOperationListOperator_test_templates"
ti = create_task_instance_of_operator(
CloudDataTransferServiceListOperationsOperator,
@@ -630,6 +640,8 @@ class TestGcpStorageTransferOperationListOperator:
gcp_conn_id="{{ dag.dag_id }}",
task_id="task-id",
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.request_filter["job_names"][0]
@@ -661,7 +673,7 @@ class TestGcpStorageTransferOperationsPauseOperator:
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
- def test_operation_pause_with_templates(self, _,
create_task_instance_of_operator):
+ def test_operation_pause_with_templates(self, _,
create_task_instance_of_operator, session):
dag_id = "test_operation_pause_with_templates"
ti = create_task_instance_of_operator(
CloudDataTransferServicePauseOperationOperator,
@@ -671,6 +683,8 @@ class TestGcpStorageTransferOperationsPauseOperator:
api_version="{{ dag.dag_id }}",
task_id=TASK_ID,
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.operation_name
assert dag_id == ti.task.gcp_conn_id
@@ -711,7 +725,7 @@ class TestGcpStorageTransferOperationsResumeOperator:
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
- def test_operation_resume_with_templates(self, _,
create_task_instance_of_operator):
+ def test_operation_resume_with_templates(self, _,
create_task_instance_of_operator, session):
dag_id = "test_operation_resume_with_templates"
ti = create_task_instance_of_operator(
CloudDataTransferServiceResumeOperationOperator,
@@ -721,6 +735,8 @@ class TestGcpStorageTransferOperationsResumeOperator:
api_version="{{ dag.dag_id }}",
task_id=TASK_ID,
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.operation_name
assert dag_id == ti.task.gcp_conn_id
@@ -764,7 +780,7 @@ class TestGcpStorageTransferOperationsCancelOperator:
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
- def test_operation_cancel_with_templates(self, _,
create_task_instance_of_operator):
+ def test_operation_cancel_with_templates(self, _,
create_task_instance_of_operator, session):
dag_id = "test_operation_cancel_with_templates"
ti = create_task_instance_of_operator(
CloudDataTransferServiceCancelOperationOperator,
@@ -774,6 +790,8 @@ class TestGcpStorageTransferOperationsCancelOperator:
api_version="{{ dag.dag_id }}",
task_id=TASK_ID,
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.operation_name
assert dag_id == ti.task.gcp_conn_id
@@ -814,7 +832,7 @@ class TestS3ToGoogleCloudStorageTransferOperator:
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
- def test_templates(self, _, create_task_instance_of_operator):
+ def test_templates(self, _, create_task_instance_of_operator, session):
dag_id = "TestS3ToGoogleCloudStorageTransferOperator_test_templates"
ti = create_task_instance_of_operator(
CloudDataTransferServiceS3ToGCSOperator,
@@ -826,6 +844,8 @@ class TestS3ToGoogleCloudStorageTransferOperator:
gcp_conn_id="{{ dag.dag_id }}",
task_id=TASK_ID,
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.s3_bucket
assert dag_id == ti.task.gcs_bucket
@@ -962,7 +982,7 @@ class
TestGoogleCloudStorageToGoogleCloudStorageTransferOperator:
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
- def test_templates(self, _, create_task_instance_of_operator):
+ def test_templates(self, _, create_task_instance_of_operator, session):
dag_id =
"TestGoogleCloudStorageToGoogleCloudStorageTransferOperator_test_templates"
ti = create_task_instance_of_operator(
CloudDataTransferServiceGCSToGCSOperator,
@@ -974,6 +994,8 @@ class
TestGoogleCloudStorageToGoogleCloudStorageTransferOperator:
gcp_conn_id="{{ dag.dag_id }}",
task_id=TASK_ID,
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.source_bucket
assert dag_id == ti.task.destination_bucket
diff --git a/tests/providers/google/cloud/operators/test_compute.py
b/tests/providers/google/cloud/operators/test_compute.py
index d1dbecbd58..fac74bae48 100644
--- a/tests/providers/google/cloud/operators/test_compute.py
+++ b/tests/providers/google/cloud/operators/test_compute.py
@@ -488,7 +488,7 @@ class TestGceInstanceStart:
# (could be anything else) just to test if the templating works for all
fields
@pytest.mark.db_test
@mock.patch(COMPUTE_ENGINE_HOOK_PATH)
- def test_start_instance_with_templates(self, _,
create_task_instance_of_operator):
+ def test_start_instance_with_templates(self, _,
create_task_instance_of_operator, session):
dag_id = "test_instance_start_with_templates"
ti = create_task_instance_of_operator(
ComputeEngineStartInstanceOperator,
@@ -500,6 +500,8 @@ class TestGceInstanceStart:
api_version="{{ dag.dag_id }}",
task_id="id",
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.project_id
assert dag_id == ti.task.zone
@@ -579,7 +581,7 @@ class TestGceInstanceStop:
# (could be anything else) just to test if the templating works for all
fields
@pytest.mark.db_test
@mock.patch(COMPUTE_ENGINE_HOOK_PATH)
- def test_instance_stop_with_templates(self, _,
create_task_instance_of_operator):
+ def test_instance_stop_with_templates(self, _,
create_task_instance_of_operator, session):
dag_id = "test_instance_stop_with_templates"
ti = create_task_instance_of_operator(
ComputeEngineStopInstanceOperator,
@@ -591,6 +593,8 @@ class TestGceInstanceStop:
api_version="{{ dag.dag_id }}",
task_id="id",
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.project_id
assert dag_id == ti.task.zone
@@ -660,7 +664,7 @@ class TestGceInstanceSetMachineType:
# (could be anything else) just to test if the templating works for all
fields
@pytest.mark.db_test
@mock.patch(COMPUTE_ENGINE_HOOK_PATH)
- def test_machine_type_set_with_templates(self, _,
create_task_instance_of_operator):
+ def test_machine_type_set_with_templates(self, _,
create_task_instance_of_operator, session):
dag_id = "test_set_machine_type_with_templates"
ti = create_task_instance_of_operator(
ComputeEngineSetMachineTypeOperator,
@@ -673,6 +677,8 @@ class TestGceInstanceSetMachineType:
api_version="{{ dag.dag_id }}",
task_id="id",
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.project_id
assert dag_id == ti.task.zone
diff --git a/tests/providers/google/cloud/operators/test_dataprep.py
b/tests/providers/google/cloud/operators/test_dataprep.py
index acbbf2063b..2377184f6a 100644
--- a/tests/providers/google/cloud/operators/test_dataprep.py
+++ b/tests/providers/google/cloud/operators/test_dataprep.py
@@ -174,7 +174,7 @@ class TestDataprepCopyFlowOperatorTest:
@pytest.mark.db_test
@mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
- def test_execute_with_templated_params(self, _,
create_task_instance_of_operator):
+ def test_execute_with_templated_params(self, _,
create_task_instance_of_operator, session):
dag_id = "test_execute_with_templated_params"
ti = create_task_instance_of_operator(
DataprepCopyFlowOperator,
@@ -185,6 +185,8 @@ class TestDataprepCopyFlowOperatorTest:
name="{{ dag.dag_id }}",
description="{{ dag.dag_id }}",
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.project_id
assert dag_id == ti.task.flow_id
@@ -248,7 +250,7 @@ class TestDataprepDeleteFlowOperator:
@pytest.mark.db_test
@mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
- def test_execute_with_template_params(self, _,
create_task_instance_of_operator):
+ def test_execute_with_template_params(self, _,
create_task_instance_of_operator, session):
dag_id = "test_execute_delete_flow_with_template"
ti = create_task_instance_of_operator(
DataprepDeleteFlowOperator,
@@ -256,6 +258,8 @@ class TestDataprepDeleteFlowOperator:
task_id=TASK_ID,
flow_id="{{ dag.dag_id }}",
)
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.flow_id
@@ -279,7 +283,7 @@ class TestDataprepRunFlowOperator:
@pytest.mark.db_test
@mock.patch("airflow.providers.google.cloud.operators.dataprep.GoogleDataprepHook")
- def test_execute_with_template_params(self, _,
create_task_instance_of_operator):
+ def test_execute_with_template_params(self, _,
create_task_instance_of_operator, session):
dag_id = "test_execute_run_flow_with_template"
ti = create_task_instance_of_operator(
DataprepRunFlowOperator,
@@ -289,7 +293,8 @@ class TestDataprepRunFlowOperator:
flow_id="{{ dag.dag_id }}",
body_request={},
)
-
+ session.add(ti)
+ session.commit()
ti.render_templates()
assert dag_id == ti.task.project_id