This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0a95c3327af Replace `models.BaseOperator` with the Task SDK one for Google
Provider (#52366)
0a95c3327af is described below
commit 0a95c3327af8e7971b1b0b0ce206ab59e2630b40
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat Jun 28 05:33:03 2025 +0530
Replace `models.BaseOperator` with the Task SDK one for Google Provider (#52366)
Follow-up of https://github.com/apache/airflow/pull/52292 for Google
provider.
---
.../airflow/providers/google/ads/operators/ads.py | 2 +-
.../providers/google/ads/transfers/ads_to_gcs.py | 2 +-
.../airflow/providers/google/cloud/links/base.py | 21 ++---
.../providers/google/cloud/links/dataproc.py | 28 +++----
.../providers/google/cloud/operators/automl.py | 7 +-
.../google/cloud/operators/bigquery_dts.py | 4 +-
.../providers/google/cloud/operators/cloud_base.py | 2 +-
.../google/cloud/operators/cloud_build.py | 12 +--
.../google/cloud/operators/datacatalog.py | 10 +--
.../providers/google/cloud/operators/dataflow.py | 14 ++--
.../providers/google/cloud/operators/dataplex.py | 15 ++--
.../providers/google/cloud/operators/dataproc.py | 6 +-
.../google/cloud/operators/dataproc_metastore.py | 2 +-
.../providers/google/cloud/operators/functions.py | 2 +-
.../google/cloud/operators/managed_kafka.py | 9 +--
.../providers/google/cloud/operators/translate.py | 6 +-
.../google/cloud/operators/vertex_ai/auto_ml.py | 16 ++--
.../operators/vertex_ai/batch_prediction_job.py | 8 +-
.../google/cloud/operators/vertex_ai/custom_job.py | 28 +++----
.../google/cloud/operators/vertex_ai/dataset.py | 2 +-
.../cloud/operators/vertex_ai/endpoint_service.py | 4 +-
.../cloud/operators/vertex_ai/generative_model.py | 12 +--
.../vertex_ai/hyperparameter_tuning_job.py | 2 +-
.../cloud/operators/vertex_ai/model_service.py | 4 +-
.../cloud/operators/vertex_ai/pipeline_job.py | 2 +-
.../google/cloud/operators/vertex_ai/ray.py | 3 +-
.../providers/google/cloud/operators/workflows.py | 2 +-
.../providers/google/cloud/sensors/bigquery.py | 7 +-
.../providers/google/cloud/sensors/dataflow.py | 7 +-
.../google/cloud/transfers/azure_blob_to_gcs.py | 2 +-
.../cloud/transfers/azure_fileshare_to_gcs.py | 2 +-
.../google/cloud/transfers/bigquery_to_bigquery.py | 2 +-
.../google/cloud/transfers/bigquery_to_gcs.py | 2 +-
.../google/cloud/transfers/bigquery_to_sql.py | 2 +-
.../google/cloud/transfers/calendar_to_gcs.py | 2 +-
.../google/cloud/transfers/cassandra_to_gcs.py | 2 +-
.../google/cloud/transfers/facebook_ads_to_gcs.py | 2 +-
.../google/cloud/transfers/gcs_to_bigquery.py | 2 +-
.../providers/google/cloud/transfers/gcs_to_gcs.py | 2 +-
.../google/cloud/transfers/gcs_to_local.py | 2 +-
.../google/cloud/transfers/gcs_to_sftp.py | 2 +-
.../google/cloud/transfers/gdrive_to_gcs.py | 6 +-
.../google/cloud/transfers/gdrive_to_local.py | 2 +-
.../google/cloud/transfers/http_to_gcs.py | 2 +-
.../google/cloud/transfers/local_to_gcs.py | 2 +-
.../google/cloud/transfers/salesforce_to_gcs.py | 2 +-
.../google/cloud/transfers/sftp_to_gcs.py | 2 +-
.../google/cloud/transfers/sheets_to_gcs.py | 4 +-
.../providers/google/cloud/transfers/sql_to_gcs.py | 2 +-
.../google/firebase/operators/firestore.py | 2 +-
.../providers/google/leveldb/operators/leveldb.py | 2 +-
.../marketing_platform/links/analytics_admin.py | 9 +--
.../operators/analytics_admin.py | 1 -
.../operators/campaign_manager.py | 8 +-
.../marketing_platform/operators/display_video.py | 12 +--
.../marketing_platform/operators/search_ads.py | 2 +-
.../providers/google/suite/operators/sheets.py | 6 +-
.../google/suite/transfers/gcs_to_gdrive.py | 2 +-
.../google/suite/transfers/gcs_to_sheets.py | 2 +-
.../google/suite/transfers/local_to_drive.py | 2 +-
.../src/airflow/providers/google/version_compat.py | 28 +++++++
.../google/cloud/operators/test_cloud_build.py | 2 +-
.../google/cloud/operators/test_datacatalog.py | 71 ++++++++---------
.../unit/google/cloud/operators/test_dataflow.py | 31 ++++----
.../unit/google/cloud/operators/test_dataproc.py | 2 +-
.../unit/google/cloud/operators/test_functions.py | 11 ++-
.../unit/google/cloud/operators/test_translate.py | 36 ++++-----
.../unit/google/cloud/operators/test_vertex_ai.py | 91 ++++++++++++----------
.../google/cloud/transfers/test_sheets_to_gcs.py | 18 +++--
.../tests/unit/google/cloud/utils/airflow_util.py | 2 +-
.../links/test_analytics_admin.py | 4 +-
.../operators/test_campaign_manager.py | 38 ++++-----
.../operators/test_display_video.py | 60 +++++++-------
.../unit/google/suite/operators/test_sheets.py | 13 ++--
74 files changed, 377 insertions(+), 363 deletions(-)
diff --git a/providers/google/src/airflow/providers/google/ads/operators/ads.py
b/providers/google/src/airflow/providers/google/ads/operators/ads.py
index b8ef110c19d..f6aaf1df3fd 100644
--- a/providers/google/src/airflow/providers/google/ads/operators/ads.py
+++ b/providers/google/src/airflow/providers/google/ads/operators/ads.py
@@ -24,9 +24,9 @@ from collections.abc import Sequence
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
-from airflow.models import BaseOperator
from airflow.providers.google.ads.hooks.ads import GoogleAdsHook
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git
a/providers/google/src/airflow/providers/google/ads/transfers/ads_to_gcs.py
b/providers/google/src/airflow/providers/google/ads/transfers/ads_to_gcs.py
index 591eaf3b4f0..baa99c5b4aa 100644
--- a/providers/google/src/airflow/providers/google/ads/transfers/ads_to_gcs.py
+++ b/providers/google/src/airflow/providers/google/ads/transfers/ads_to_gcs.py
@@ -22,9 +22,9 @@ from operator import attrgetter
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
-from airflow.models import BaseOperator
from airflow.providers.google.ads.hooks.ads import GoogleAdsHook
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git a/providers/google/src/airflow/providers/google/cloud/links/base.py
b/providers/google/src/airflow/providers/google/cloud/links/base.py
index b24b780c613..d543b495ab1 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/base.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/base.py
@@ -19,22 +19,23 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
-from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
-
-if TYPE_CHECKING:
- from airflow.models import BaseOperator
- from airflow.models.taskinstancekey import TaskInstanceKey
- from airflow.providers.google.cloud.operators.cloud_base import
GoogleCloudBaseOperator
- from airflow.sdk import BaseSensorOperator
- from airflow.utils.context import Context
+from airflow.providers.google.version_compat import (
+ AIRFLOW_V_3_0_PLUS,
+ BaseOperator,
+ BaseOperatorLink,
+ BaseSensorOperator,
+)
if AIRFLOW_V_3_0_PLUS:
- from airflow.sdk import BaseOperatorLink
from airflow.sdk.execution_time.xcom import XCom
else:
- from airflow.models.baseoperatorlink import BaseOperatorLink # type:
ignore[no-redef]
from airflow.models.xcom import XCom # type: ignore[no-redef]
+if TYPE_CHECKING:
+ from airflow.models.taskinstancekey import TaskInstanceKey
+ from airflow.providers.google.cloud.operators.cloud_base import
GoogleCloudBaseOperator
+ from airflow.utils.context import Context
+
BASE_LINK = "https://console.cloud.google.com"
diff --git
a/providers/google/src/airflow/providers/google/cloud/links/dataproc.py
b/providers/google/src/airflow/providers/google/cloud/links/dataproc.py
index 832c66bf132..67be746ff4d 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/dataproc.py
@@ -26,19 +26,20 @@ import attr
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.links.base import BASE_LINK, BaseGoogleLink
-from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.google.version_compat import (
+ AIRFLOW_V_3_0_PLUS,
+ BaseOperator,
+ BaseOperatorLink,
+)
if TYPE_CHECKING:
- from airflow.models import BaseOperator
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context
if AIRFLOW_V_3_0_PLUS:
- from airflow.sdk import BaseOperatorLink
from airflow.sdk.execution_time.xcom import XCom
else:
- from airflow.models import XCom # type: ignore[no-redef]
- from airflow.models.baseoperatorlink import BaseOperatorLink # type:
ignore[no-redef]
+ from airflow.models.xcom import XCom # type: ignore[no-redef]
def __getattr__(name: str) -> Any:
@@ -94,16 +95,16 @@ class DataprocLink(BaseOperatorLink):
@staticmethod
def persist(
context: Context,
- task_instance,
url: str,
resource: str,
+ region: str,
+ project_id: str,
):
- task_instance.xcom_push(
- context=context,
+ context["task_instance"].xcom_push(
key=DataprocLink.key,
value={
- "region": task_instance.region,
- "project_id": task_instance.project_id,
+ "region": region,
+ "project_id": project_id,
"url": url,
"resource": resource,
},
@@ -147,14 +148,13 @@ class DataprocListLink(BaseOperatorLink):
@staticmethod
def persist(
context: Context,
- task_instance,
url: str,
+ project_id: str,
):
- task_instance.xcom_push(
- context=context,
+ context["task_instance"].xcom_push(
key=DataprocListLink.key,
value={
- "project_id": task_instance.project_id,
+ "project_id": project_id,
"url": url,
},
)
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/automl.py
b/providers/google/src/airflow/providers/google/cloud/operators/automl.py
index 648cee403be..7d2cd4ee49f 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/automl.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/automl.py
@@ -163,7 +163,7 @@ class AutoMLTrainModelOperator(GoogleCloudBaseOperator):
model_id = hook.extract_object_id(result)
self.log.info("Model is created, model_id: %s", model_id)
- self.xcom_push(context, key="model_id", value=model_id)
+ context["task_instance"].xcom_push(key="model_id", value=model_id)
if project_id:
TranslationLegacyModelLink.persist(
context=context,
@@ -415,7 +415,7 @@ class AutoMLCreateDatasetOperator(GoogleCloudBaseOperator):
dataset_id = hook.extract_object_id(result)
self.log.info("Creating completed. Dataset id: %s", dataset_id)
- self.xcom_push(context, key="dataset_id", value=dataset_id)
+ context["task_instance"].xcom_push(key="dataset_id", value=dataset_id)
project_id = self.project_id or hook.project_id
if project_id:
TranslationLegacyDatasetLink.persist(
@@ -1248,8 +1248,7 @@ class AutoMLListDatasetOperator(GoogleCloudBaseOperator):
result.append(Dataset.to_dict(dataset))
self.log.info("Datasets obtained.")
- self.xcom_push(
- context,
+ context["task_instance"].xcom_push(
key="dataset_id_list",
value=[hook.extract_object_id(d) for d in result],
)
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py
b/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py
index 387a2ab55f3..b06dea7bc02 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/bigquery_dts.py
@@ -141,7 +141,7 @@ class
BigQueryCreateDataTransferOperator(GoogleCloudBaseOperator):
result = TransferConfig.to_dict(response)
self.log.info("Created DTS transfer config %s", get_object_id(result))
- self.xcom_push(context, key="transfer_config_id",
value=get_object_id(result))
+ context["ti"].xcom_push(key="transfer_config_id",
value=get_object_id(result))
# don't push AWS secret in XCOM
result.get("params", {}).pop("secret_access_key", None)
result.get("params", {}).pop("access_key_id", None)
@@ -335,7 +335,7 @@ class
BigQueryDataTransferServiceStartTransferRunsOperator(GoogleCloudBaseOperat
result = StartManualTransferRunsResponse.to_dict(response)
run_id = get_object_id(result["runs"][0])
- self.xcom_push(context, key="run_id", value=run_id)
+ context["ti"].xcom_push(key="run_id", value=run_id)
if not self.deferrable:
# Save as attribute for further use by OpenLineage
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/cloud_base.py
b/providers/google/src/airflow/providers/google/cloud/operators/cloud_base.py
index fb11a7276fb..4e7e2cf6ea9 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/cloud_base.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/cloud_base.py
@@ -23,7 +23,7 @@ from typing import Any
from google.api_core.gapic_v1.method import DEFAULT
-from airflow.models import BaseOperator
+from airflow.providers.google.version_compat import BaseOperator
class GoogleCloudBaseOperator(BaseOperator):
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py
b/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py
index 78d4bc2d606..26f1444df36 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/cloud_build.py
@@ -125,7 +125,7 @@ class
CloudBuildCancelBuildOperator(GoogleCloudBaseOperator):
location=self.location,
)
- self.xcom_push(context, key="id", value=result.id)
+ context["task_instance"].xcom_push(key="id", value=result.id)
project_id = self.project_id or hook.project_id
if project_id:
CloudBuildLink.persist(
@@ -235,7 +235,7 @@ class
CloudBuildCreateBuildOperator(GoogleCloudBaseOperator):
metadata=self.metadata,
location=self.location,
)
- self.xcom_push(context, key="id", value=self.id_)
+ context["task_instance"].xcom_push(key="id", value=self.id_)
if not self.wait:
return Build.to_dict(
hook.get_build(id_=self.id_, project_id=self.project_id,
location=self.location)
@@ -358,7 +358,7 @@ class
CloudBuildCreateBuildTriggerOperator(GoogleCloudBaseOperator):
metadata=self.metadata,
location=self.location,
)
- self.xcom_push(context, key="id", value=result.id)
+ context["task_instance"].xcom_push(key="id", value=result.id)
project_id = self.project_id or hook.project_id
if project_id:
CloudBuildTriggerDetailsLink.persist(
@@ -854,7 +854,7 @@ class CloudBuildRetryBuildOperator(GoogleCloudBaseOperator):
location=self.location,
)
- self.xcom_push(context, key="id", value=result.id)
+ context["task_instance"].xcom_push(key="id", value=result.id)
project_id = self.project_id or hook.project_id
if project_id:
CloudBuildLink.persist(
@@ -944,7 +944,7 @@ class
CloudBuildRunBuildTriggerOperator(GoogleCloudBaseOperator):
metadata=self.metadata,
location=self.location,
)
- self.xcom_push(context, key="id", value=result.id)
+ context["task_instance"].xcom_push(key="id", value=result.id)
project_id = self.project_id or hook.project_id
if project_id:
CloudBuildLink.persist(
@@ -1030,7 +1030,7 @@ class
CloudBuildUpdateBuildTriggerOperator(GoogleCloudBaseOperator):
metadata=self.metadata,
location=self.location,
)
- self.xcom_push(context, key="id", value=result.id)
+ context["task_instance"].xcom_push(key="id", value=result.id)
project_id = self.project_id or hook.project_id
if project_id:
CloudBuildTriggerDetailsLink.persist(
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/datacatalog.py
b/providers/google/src/airflow/providers/google/cloud/operators/datacatalog.py
index 48b444f8b13..ba6692cb1e8 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/datacatalog.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/datacatalog.py
@@ -163,7 +163,7 @@ class
CloudDataCatalogCreateEntryOperator(GoogleCloudBaseOperator):
)
_, _, entry_id = result.name.rpartition("/")
self.log.info("Current entry_id ID: %s", entry_id)
- self.xcom_push(context, key="entry_id", value=entry_id)
+ context["ti"].xcom_push(key="entry_id", value=entry_id)
DataCatalogEntryLink.persist(
context=context,
entry_id=self.entry_id,
@@ -283,7 +283,7 @@ class
CloudDataCatalogCreateEntryGroupOperator(GoogleCloudBaseOperator):
_, _, entry_group_id = result.name.rpartition("/")
self.log.info("Current entry group ID: %s", entry_group_id)
- self.xcom_push(context, key="entry_group_id", value=entry_group_id)
+ context["ti"].xcom_push(key="entry_group_id", value=entry_group_id)
DataCatalogEntryGroupLink.persist(
context=context,
entry_group_id=self.entry_group_id,
@@ -425,7 +425,7 @@ class
CloudDataCatalogCreateTagOperator(GoogleCloudBaseOperator):
_, _, tag_id = tag.name.rpartition("/")
self.log.info("Current Tag ID: %s", tag_id)
- self.xcom_push(context, key="tag_id", value=tag_id)
+ context["ti"].xcom_push(key="tag_id", value=tag_id)
DataCatalogEntryLink.persist(
context=context,
entry_id=self.entry,
@@ -542,7 +542,7 @@ class
CloudDataCatalogCreateTagTemplateOperator(GoogleCloudBaseOperator):
)
_, _, tag_template = result.name.rpartition("/")
self.log.info("Current Tag ID: %s", tag_template)
- self.xcom_push(context, key="tag_template_id", value=tag_template)
+ context["ti"].xcom_push(key="tag_template_id", value=tag_template)
DataCatalogTagTemplateLink.persist(
context=context,
tag_template_id=self.tag_template_id,
@@ -668,7 +668,7 @@ class
CloudDataCatalogCreateTagTemplateFieldOperator(GoogleCloudBaseOperator):
result = tag_template.fields[self.tag_template_field_id]
self.log.info("Current Tag ID: %s", self.tag_template_field_id)
- self.xcom_push(context, key="tag_template_field_id",
value=self.tag_template_field_id)
+ context["ti"].xcom_push(key="tag_template_field_id",
value=self.tag_template_field_id)
DataCatalogTagTemplateLink.persist(
context=context,
tag_template_id=self.tag_template,
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/dataflow.py
b/providers/google/src/airflow/providers/google/cloud/operators/dataflow.py
index c61a19b6a34..88542a8f1e4 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataflow.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataflow.py
@@ -409,7 +409,7 @@ class
DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
append_job_name=self.append_job_name,
)
job_id = self.hook.extract_job_id(self.job)
- self.xcom_push(context, key="job_id", value=job_id)
+ context["task_instance"].xcom_push(key="job_id", value=job_id)
return job_id
self.job = self.hook.launch_job_with_template(
@@ -446,7 +446,7 @@ class
DataflowTemplatedJobStartOperator(GoogleCloudBaseOperator):
raise AirflowException(event["message"])
job_id = event["job_id"]
- self.xcom_push(context, key="job_id", value=job_id)
+ context["task_instance"].xcom_push(key="job_id", value=job_id)
self.log.info("Task %s completed with response %s", self.task_id,
event["message"])
return job_id
@@ -609,7 +609,7 @@ class
DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
on_new_job_callback=set_current_job,
)
job_id = self.hook.extract_job_id(self.job)
- self.xcom_push(context, key="job_id", value=job_id)
+ context["task_instance"].xcom_push(key="job_id", value=job_id)
return self.job
self.job = self.hook.launch_job_with_flex_template(
@@ -650,7 +650,7 @@ class
DataflowStartFlexTemplateOperator(GoogleCloudBaseOperator):
job_id = event["job_id"]
self.log.info("Task %s completed with response %s", job_id,
event["message"])
- self.xcom_push(context, key="job_id", value=job_id)
+ context["task_instance"].xcom_push(key="job_id", value=job_id)
job = self.hook.get_job(job_id=job_id, project_id=self.project_id,
location=self.location)
return job
@@ -807,7 +807,7 @@ class DataflowStartYamlJobOperator(GoogleCloudBaseOperator):
raise AirflowException(event["message"])
job = event["job"]
self.log.info("Job %s completed with response %s", job["id"],
event["message"])
- self.xcom_push(context, key="job_id", value=job["id"])
+ context["task_instance"].xcom_push(key="job_id", value=job["id"])
return job
@@ -1025,7 +1025,7 @@ class
DataflowCreatePipelineOperator(GoogleCloudBaseOperator):
location=self.location,
)
DataflowPipelineLink.persist(context=context)
- self.xcom_push(context, key="pipeline_name", value=self.pipeline_name)
+ context["task_instance"].xcom_push(key="pipeline_name",
value=self.pipeline_name)
if self.pipeline:
if "error" in self.pipeline:
raise
AirflowException(self.pipeline.get("error").get("message"))
@@ -1096,7 +1096,7 @@ class
DataflowRunPipelineOperator(GoogleCloudBaseOperator):
location=self.location,
)["job"]
job_id = self.dataflow_hook.extract_job_id(self.job)
- self.xcom_push(context, key="job_id", value=job_id)
+ context["task_instance"].xcom_push(key="job_id", value=job_id)
DataflowJobLink.persist(
context=context, project_id=self.project_id,
region=self.location, job_id=job_id
)
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/dataplex.py
b/providers/google/src/airflow/providers/google/cloud/operators/dataplex.py
index ad6a7217e2c..45361b364ab 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataplex.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataplex.py
@@ -2533,8 +2533,7 @@ class
DataplexCatalogListEntryGroupsOperator(DataplexCatalogBaseOperator):
metadata=self.metadata,
)
self.log.info("EntryGroup on page: %s", entry_group_on_page)
- self.xcom_push(
- context=context,
+ context["ti"].xcom_push(
key="entry_group_page",
value=ListEntryGroupsResponse.to_dict(entry_group_on_page._response),
)
@@ -2954,8 +2953,7 @@ class
DataplexCatalogListEntryTypesOperator(DataplexCatalogBaseOperator):
metadata=self.metadata,
)
self.log.info("EntryType on page: %s", entry_type_on_page)
- self.xcom_push(
- context=context,
+ context["ti"].xcom_push(
key="entry_type_page",
value=ListEntryTypesResponse.to_dict(entry_type_on_page._response),
)
@@ -3308,8 +3306,7 @@ class
DataplexCatalogListAspectTypesOperator(DataplexCatalogBaseOperator):
metadata=self.metadata,
)
self.log.info("AspectType on page: %s", aspect_type_on_page)
- self.xcom_push(
- context=context,
+ context["ti"].xcom_push(
key="aspect_type_page",
value=ListAspectTypesResponse.to_dict(aspect_type_on_page._response),
)
@@ -3803,8 +3800,7 @@ class
DataplexCatalogListEntriesOperator(DataplexCatalogBaseOperator):
metadata=self.metadata,
)
self.log.info("Entries on page: %s", entries_on_page)
- self.xcom_push(
- context=context,
+ context["ti"].xcom_push(
key="entry_page",
value=ListEntriesResponse.to_dict(entries_on_page._response),
)
@@ -3901,8 +3897,7 @@ class
DataplexCatalogSearchEntriesOperator(DataplexCatalogBaseOperator):
metadata=self.metadata,
)
self.log.info("Entries on page: %s", entries_on_page)
- self.xcom_push(
- context=context,
+ context["ti"].xcom_push(
key="entry_page",
value=SearchEntriesResponse.to_dict(entries_on_page._response),
)
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index 3246c5bb6c3..5ed6923821a 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -1353,7 +1353,11 @@ class DataprocJobBaseOperator(GoogleCloudBaseOperator):
self.log.info("Job %s submitted successfully.", job_id)
# Save data required for extra links no matter what the job status
will be
DataprocLink.persist(
- context=context, task_instance=self,
url=DATAPROC_JOB_LINK_DEPRECATED, resource=job_id
+ context=context,
+ url=DATAPROC_JOB_LINK_DEPRECATED,
+ resource=job_id,
+ region=self.region,
+ project_id=self.project_id,
)
if self.deferrable:
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py
b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py
index 01340af99c6..25743a13a70 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py
@@ -39,8 +39,8 @@ from airflow.providers.google.common.links.storage import
StorageLink
if TYPE_CHECKING:
from google.protobuf.field_mask_pb2 import FieldMask
- from airflow.models import BaseOperator
from airflow.models.taskinstancekey import TaskInstanceKey
+ from airflow.providers.google.version_compat import BaseOperator
from airflow.utils.context import Context
BASE_LINK = "https://console.cloud.google.com"
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/functions.py
b/providers/google/src/airflow/providers/google/cloud/operators/functions.py
index 77dee5089f2..f31f93c27db 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/functions.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/functions.py
@@ -488,7 +488,7 @@ class
CloudFunctionInvokeFunctionOperator(GoogleCloudBaseOperator):
project_id=self.project_id,
)
self.log.info("Function called successfully. Execution id %s",
result.get("executionId"))
- self.xcom_push(context=context, key="execution_id",
value=result.get("executionId"))
+ context["ti"].xcom_push(key="execution_id",
value=result.get("executionId"))
project_id = self.project_id or hook.project_id
if project_id:
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
index 6c42eb8be1d..c4f7d09e578 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
@@ -256,8 +256,7 @@ class
ManagedKafkaListClustersOperator(ManagedKafkaBaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- self.xcom_push(
- context=context,
+ context["ti"].xcom_push(
key="cluster_page",
value=types.ListClustersResponse.to_dict(cluster_list_pager._response),
)
@@ -622,8 +621,7 @@ class
ManagedKafkaListTopicsOperator(ManagedKafkaBaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- self.xcom_push(
- context=context,
+ context["ti"].xcom_push(
key="topic_page",
value=types.ListTopicsResponse.to_dict(topic_list_pager._response),
)
@@ -897,8 +895,7 @@ class
ManagedKafkaListConsumerGroupsOperator(ManagedKafkaBaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- self.xcom_push(
- context=context,
+ context["ti"].xcom_push(
key="consumer_group_page",
value=types.ListConsumerGroupsResponse.to_dict(consumer_group_list_pager._response),
)
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/translate.py
b/providers/google/src/airflow/providers/google/cloud/operators/translate.py
index dd30b4536ae..ba6b0c338ab 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/translate.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/translate.py
@@ -479,7 +479,7 @@ class
TranslateCreateDatasetOperator(GoogleCloudBaseOperator):
result = hook.wait_for_operation_result(result_operation)
result = type(result).to_dict(result)
dataset_id = hook.extract_object_id(result)
- self.xcom_push(context, key="dataset_id", value=dataset_id)
+ context["ti"].xcom_push(key="dataset_id", value=dataset_id)
self.log.info("Dataset creation complete. The dataset_id: %s.",
dataset_id)
project_id = self.project_id or hook.project_id
@@ -819,7 +819,7 @@ class TranslateCreateModelOperator(GoogleCloudBaseOperator):
result = hook.wait_for_operation_result(operation=result_operation)
result = type(result).to_dict(result)
model_id = hook.extract_object_id(result)
- self.xcom_push(context, key="model_id", value=model_id)
+ context["ti"].xcom_push(key="model_id", value=model_id)
self.log.info("Model creation complete. The model_id: %s.", model_id)
project_id = self.project_id or hook.project_id
@@ -1406,7 +1406,7 @@ class
TranslateCreateGlossaryOperator(GoogleCloudBaseOperator):
result = type(result).to_dict(result)
glossary_id = hook.extract_object_id(result)
- self.xcom_push(context, key="glossary_id", value=glossary_id)
+ context["ti"].xcom_push(key="glossary_id", value=glossary_id)
self.log.info("Glossary creation complete. The glossary_id: %s.",
glossary_id)
return result
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
index e1a1f50575a..4251b4becaa 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py
@@ -249,11 +249,11 @@ class
CreateAutoMLForecastingTrainingJobOperator(AutoMLTrainingJobBaseOperator):
if model:
result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
- self.xcom_push(context, key="model_id", value=model_id)
+ context["ti"].xcom_push(key="model_id", value=model_id)
VertexAIModelLink.persist(context=context, model_id=model_id)
else:
result = model # type: ignore
- self.xcom_push(context, key="training_id", value=training_id)
+ context["ti"].xcom_push(key="training_id", value=training_id)
VertexAITrainingLink.persist(context=context, training_id=training_id)
return result
@@ -341,11 +341,11 @@ class
CreateAutoMLImageTrainingJobOperator(AutoMLTrainingJobBaseOperator):
if model:
result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
- self.xcom_push(context, key="model_id", value=model_id)
+ context["ti"].xcom_push(key="model_id", value=model_id)
VertexAIModelLink.persist(context=context, model_id=model_id)
else:
result = model # type: ignore
- self.xcom_push(context, key="training_id", value=training_id)
+ context["ti"].xcom_push(key="training_id", value=training_id)
VertexAITrainingLink.persist(context=context, training_id=training_id)
return result
@@ -464,11 +464,11 @@ class
CreateAutoMLTabularTrainingJobOperator(AutoMLTrainingJobBaseOperator):
if model:
result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
- self.xcom_push(context, key="model_id", value=model_id)
+ context["ti"].xcom_push(key="model_id", value=model_id)
VertexAIModelLink.persist(context=context, model_id=model_id)
else:
result = model # type: ignore
- self.xcom_push(context, key="training_id", value=training_id)
+ context["ti"].xcom_push(key="training_id", value=training_id)
VertexAITrainingLink.persist(context=context, training_id=training_id)
return result
@@ -538,11 +538,11 @@ class
CreateAutoMLVideoTrainingJobOperator(AutoMLTrainingJobBaseOperator):
if model:
result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
- self.xcom_push(context, key="model_id", value=model_id)
+ context["ti"].xcom_push(key="model_id", value=model_id)
VertexAIModelLink.persist(context=context, model_id=model_id)
else:
result = model # type: ignore
- self.xcom_push(context, key="training_id", value=training_id)
+ context["ti"].xcom_push(key="training_id", value=training_id)
VertexAITrainingLink.persist(context=context, training_id=training_id)
return result
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
index 156aefdc6fa..4aad729a9df 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/batch_prediction_job.py
@@ -269,7 +269,7 @@ class
CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
batch_prediction_job_id = batch_prediction_job.name
self.log.info("Batch prediction job was created. Job id: %s",
batch_prediction_job_id)
- self.xcom_push(context, key="batch_prediction_job_id",
value=batch_prediction_job_id)
+ context["ti"].xcom_push(key="batch_prediction_job_id",
value=batch_prediction_job_id)
VertexAIBatchPredictionJobLink.persist(
context=context,
batch_prediction_job_id=batch_prediction_job_id,
@@ -303,13 +303,11 @@ class
CreateBatchPredictionJobOperator(GoogleCloudBaseOperator):
job: dict[str, Any] = event["job"]
self.log.info("Batch prediction job %s created and completed
successfully.", job["name"])
job_id = self.hook.extract_batch_prediction_job_id(job)
- self.xcom_push(
- context,
+ context["ti"].xcom_push(
key="batch_prediction_job_id",
value=job_id,
)
- self.xcom_push(
- context,
+ context["ti"].xcom_push(
key="training_conf",
value={
"training_conf_id": job_id,
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
index 9543f27c719..24a7efaeadc 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
@@ -182,11 +182,11 @@ class
CustomTrainingJobBaseOperator(GoogleCloudBaseOperator):
raise AirflowException(event["message"])
training_pipeline = event["job"]
custom_job_id =
self.hook.extract_custom_job_id_from_training_pipeline(training_pipeline)
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
try:
model = training_pipeline["model_to_upload"]
model_id = self.hook.extract_model_id(model)
- self.xcom_push(context, key="model_id", value=model_id)
+ context["ti"].xcom_push(key="model_id", value=model_id)
VertexAIModelLink.persist(context=context, model_id=model_id)
return model
except KeyError:
@@ -591,12 +591,12 @@ class
CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
if model:
result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
- self.xcom_push(context, key="model_id", value=model_id)
+ context["ti"].xcom_push(key="model_id", value=model_id)
VertexAIModelLink.persist(context=context, model_id=model_id)
else:
result = model # type: ignore
- self.xcom_push(context, key="training_id", value=training_id)
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
+ context["ti"].xcom_push(key="training_id", value=training_id)
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
VertexAITrainingLink.persist(context=context, training_id=training_id)
return result
@@ -655,7 +655,7 @@ class
CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
)
custom_container_training_job_obj.wait_for_resource_creation()
training_pipeline_id: str = custom_container_training_job_obj.name
- self.xcom_push(context, key="training_id", value=training_pipeline_id)
+ context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
VertexAITrainingLink.persist(context=context,
training_id=training_pipeline_id)
self.defer(
trigger=CustomContainerTrainingJobTrigger(
@@ -1048,12 +1048,12 @@ class
CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
if model:
result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
- self.xcom_push(context, key="model_id", value=model_id)
+ context["ti"].xcom_push(key="model_id", value=model_id)
VertexAIModelLink.persist(context=context, model_id=model_id)
else:
result = model # type: ignore
- self.xcom_push(context, key="training_id", value=training_id)
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
+ context["ti"].xcom_push(key="training_id", value=training_id)
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
VertexAITrainingLink.persist(context=context, training_id=training_id)
return result
@@ -1113,7 +1113,7 @@ class
CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
)
custom_python_training_job_obj.wait_for_resource_creation()
training_pipeline_id: str = custom_python_training_job_obj.name
- self.xcom_push(context, key="training_id", value=training_pipeline_id)
+ context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
VertexAITrainingLink.persist(context=context,
training_id=training_pipeline_id)
self.defer(
trigger=CustomPythonPackageTrainingJobTrigger(
@@ -1511,12 +1511,12 @@ class
CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
if model:
result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
- self.xcom_push(context, key="model_id", value=model_id)
+ context["ti"].xcom_push(key="model_id", value=model_id)
VertexAIModelLink.persist(context=context, model_id=model_id)
else:
result = model # type: ignore
- self.xcom_push(context, key="training_id", value=training_id)
- self.xcom_push(context, key="custom_job_id", value=custom_job_id)
+ context["ti"].xcom_push(key="training_id", value=training_id)
+ context["ti"].xcom_push(key="custom_job_id", value=custom_job_id)
VertexAITrainingLink.persist(context=context, training_id=training_id)
return result
@@ -1576,7 +1576,7 @@ class
CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
)
custom_training_job_obj.wait_for_resource_creation()
training_pipeline_id: str = custom_training_job_obj.name
- self.xcom_push(context, key="training_id", value=training_pipeline_id)
+ context["ti"].xcom_push(key="training_id", value=training_pipeline_id)
VertexAITrainingLink.persist(context=context,
training_id=training_pipeline_id)
self.defer(
trigger=CustomTrainingJobTrigger(
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
index 594b8e4fcda..a2c0c79eb0f 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py
@@ -113,7 +113,7 @@ class CreateDatasetOperator(GoogleCloudBaseOperator):
dataset_id = hook.extract_dataset_id(dataset)
self.log.info("Dataset was created. Dataset id: %s", dataset_id)
- self.xcom_push(context, key="dataset_id", value=dataset_id)
+ context["ti"].xcom_push(key="dataset_id", value=dataset_id)
VertexAIDatasetLink.persist(context=context, dataset_id=dataset_id)
return dataset
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py
index 32a64e17bca..9871cdccdc4 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/endpoint_service.py
@@ -122,7 +122,7 @@ class CreateEndpointOperator(GoogleCloudBaseOperator):
endpoint_id = hook.extract_endpoint_id(endpoint)
self.log.info("Endpoint was created. Endpoint ID: %s", endpoint_id)
- self.xcom_push(context, key="endpoint_id", value=endpoint_id)
+ context["ti"].xcom_push(key="endpoint_id", value=endpoint_id)
VertexAIEndpointLink.persist(context=context, endpoint_id=endpoint_id)
return endpoint
@@ -292,7 +292,7 @@ class DeployModelOperator(GoogleCloudBaseOperator):
deployed_model_id = hook.extract_deployed_model_id(deploy_model)
self.log.info("Model was deployed. Deployed Model ID: %s",
deployed_model_id)
- self.xcom_push(context, key="deployed_model_id",
value=deployed_model_id)
+ context["ti"].xcom_push(key="deployed_model_id",
value=deployed_model_id)
VertexAIModelLink.persist(context=context, model_id=deployed_model_id)
return deploy_model
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
index 20257dad196..2c47b2ee5af 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -93,7 +93,7 @@ class
TextEmbeddingModelGetEmbeddingsOperator(GoogleCloudBaseOperator):
)
self.log.info("Model response: %s", response)
- self.xcom_push(context, key="model_response", value=response)
+ context["ti"].xcom_push(key="model_response", value=response)
return response
@@ -172,7 +172,7 @@ class
GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
)
self.log.info("Model response: %s", response)
- self.xcom_push(context, key="model_response", value=response)
+ context["ti"].xcom_push(key="model_response", value=response)
return response
@@ -261,8 +261,8 @@ class
SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
self.log.info("Tuned Model Name: %s", response.tuned_model_name)
self.log.info("Tuned Model Endpoint Name: %s",
response.tuned_model_endpoint_name)
- self.xcom_push(context, key="tuned_model_name",
value=response.tuned_model_name)
- self.xcom_push(context, key="tuned_model_endpoint_name",
value=response.tuned_model_endpoint_name)
+ context["ti"].xcom_push(key="tuned_model_name",
value=response.tuned_model_name)
+ context["ti"].xcom_push(key="tuned_model_endpoint_name",
value=response.tuned_model_endpoint_name)
result = {
"tuned_model_name": response.tuned_model_name,
@@ -332,8 +332,8 @@ class CountTokensOperator(GoogleCloudBaseOperator):
self.log.info("Total tokens: %s", response.total_tokens)
self.log.info("Total billable characters: %s",
response.total_billable_characters)
- self.xcom_push(context, key="total_tokens",
value=response.total_tokens)
- self.xcom_push(context, key="total_billable_characters",
value=response.total_billable_characters)
+ context["ti"].xcom_push(key="total_tokens",
value=response.total_tokens)
+ context["ti"].xcom_push(key="total_billable_characters",
value=response.total_billable_characters)
class RunEvaluationOperator(GoogleCloudBaseOperator):
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py
index 86d278c9ab1..a667778965b 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/hyperparameter_tuning_job.py
@@ -257,7 +257,7 @@ class
CreateHyperparameterTuningJobOperator(GoogleCloudBaseOperator):
hyperparameter_tuning_job_id = hyperparameter_tuning_job.name
self.log.info("Hyperparameter Tuning job was created. Job id: %s",
hyperparameter_tuning_job_id)
- self.xcom_push(context, key="hyperparameter_tuning_job_id",
value=hyperparameter_tuning_job_id)
+ context["ti"].xcom_push(key="hyperparameter_tuning_job_id",
value=hyperparameter_tuning_job_id)
VertexAITrainingLink.persist(context=context,
training_id=hyperparameter_tuning_job_id)
if self.deferrable:
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/model_service.py
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/model_service.py
index d5f9c26e5a0..2b6459d4349 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/model_service.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/model_service.py
@@ -186,7 +186,7 @@ class GetModelOperator(GoogleCloudBaseOperator):
)
self.log.info("Model found. Model ID: %s", self.model_id)
- self.xcom_push(context, key="model_id", value=self.model_id)
+ context["ti"].xcom_push(key="model_id", value=self.model_id)
VertexAIModelLink.persist(context=context, model_id=self.model_id)
return Model.to_dict(model)
except NotFound:
@@ -453,7 +453,7 @@ class UploadModelOperator(GoogleCloudBaseOperator):
model_id = hook.extract_model_id(model_resp)
self.log.info("Model was uploaded. Model ID: %s", model_id)
- self.xcom_push(context, key="model_id", value=model_id)
+ context["ti"].xcom_push(key="model_id", value=model_id)
VertexAIModelLink.persist(context=context, model_id=model_id)
return model_resp
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py
index d12adaaf272..875d2eff00e 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py
@@ -195,7 +195,7 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
)
pipeline_job_id = pipeline_job_obj.job_id
self.log.info("Pipeline job was created. Job id: %s", pipeline_job_id)
- self.xcom_push(context, key="pipeline_job_id", value=pipeline_job_id)
+ context["ti"].xcom_push(key="pipeline_job_id", value=pipeline_job_id)
VertexAIPipelineJobLink.persist(context=context,
pipeline_id=pipeline_job_id)
if self.deferrable:
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
index 6368e81eb01..0cd222ec6d2 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
@@ -188,8 +188,7 @@ class CreateRayClusterOperator(RayBaseOperator):
labels=self.labels,
)
cluster_id = self.hook.extract_cluster_id(cluster_path)
- self.xcom_push(
- context=context,
+ context["ti"].xcom_push(
key="cluster_id",
value=cluster_id,
)
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/workflows.py
b/providers/google/src/airflow/providers/google/cloud/operators/workflows.py
index 81f97d18a6d..51da294c664 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/workflows.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/workflows.py
@@ -501,7 +501,7 @@ class
WorkflowsCreateExecutionOperator(GoogleCloudBaseOperator):
metadata=self.metadata,
)
execution_id = execution.name.split("/")[-1]
- self.xcom_push(context, key="execution_id", value=execution_id)
+ context["task_instance"].xcom_push(key="execution_id",
value=execution_id)
WorkflowsExecutionLink.persist(
context=context,
diff --git
a/providers/google/src/airflow/providers/google/cloud/sensors/bigquery.py
b/providers/google/src/airflow/providers/google/cloud/sensors/bigquery.py
index 2b82d5cfff1..f08757f0edd 100644
--- a/providers/google/src/airflow/providers/google/cloud/sensors/bigquery.py
+++ b/providers/google/src/airflow/providers/google/cloud/sensors/bigquery.py
@@ -31,12 +31,7 @@ from airflow.providers.google.cloud.triggers.bigquery import
(
BigQueryTableExistenceTrigger,
BigQueryTablePartitionExistenceTrigger,
)
-from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
-
-if AIRFLOW_V_3_0_PLUS:
- from airflow.sdk import BaseSensorOperator
-else:
- from airflow.sensors.base import BaseSensorOperator # type:
ignore[no-redef]
+from airflow.providers.google.version_compat import BaseSensorOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git
a/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py
b/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py
index 38d131851e2..bc71934e9dc 100644
--- a/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py
+++ b/providers/google/src/airflow/providers/google/cloud/sensors/dataflow.py
@@ -37,12 +37,7 @@ from airflow.providers.google.cloud.triggers.dataflow import
(
DataflowJobStatusTrigger,
)
from airflow.providers.google.common.hooks.base_google import
PROVIDE_PROJECT_ID
-from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
-
-if AIRFLOW_V_3_0_PLUS:
- from airflow.sdk import BaseSensorOperator
-else:
- from airflow.sensors.base import BaseSensorOperator # type:
ignore[no-redef]
+from airflow.providers.google.version_compat import BaseSensorOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
index 303f0c00761..edbfd61b1a2 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/azure_blob_to_gcs.py
@@ -21,8 +21,8 @@ import tempfile
from collections.abc import Sequence
from typing import TYPE_CHECKING
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.version_compat import BaseOperator
try:
from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py
index 9dd3ca6015f..4b890b6c59d 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py
@@ -23,8 +23,8 @@ from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url,
gcs_object_is_directory
+from airflow.providers.google.version_compat import BaseOperator
try:
from airflow.providers.microsoft.azure.hooks.fileshare import
AzureFileShareHook
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
index 603fe3929da..d34c71a7dc0 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py
@@ -22,10 +22,10 @@ from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
from airflow.providers.google.common.hooks.base_google import
PROVIDE_PROJECT_ID
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
index 15318078846..719d565a9fc 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
@@ -27,11 +27,11 @@ from google.cloud.bigquery import DEFAULT_RETRY, UnknownJob
from airflow.configuration import conf
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook,
BigQueryJob
from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
from airflow.providers.google.cloud.triggers.bigquery import
BigQueryInsertJobTrigger
from airflow.providers.google.common.hooks.base_google import
PROVIDE_PROJECT_ID
+from airflow.providers.google.version_compat import BaseOperator
from airflow.utils.helpers import merge_dicts
if TYPE_CHECKING:
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py
b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py
index 081bb8df9d6..dc3ad68fb81 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py
@@ -23,9 +23,9 @@ import abc
from collections.abc import Sequence
from typing import TYPE_CHECKING
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.utils.bigquery_get_data import
bigquery_get_data
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.providers.common.sql.hooks.sql import DbApiHook
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/calendar_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/calendar_to_gcs.py
index f9d7b977db0..bc6e9628b54 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/calendar_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/calendar_to_gcs.py
@@ -21,9 +21,9 @@ from collections.abc import Sequence
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.suite.hooks.calendar import GoogleCalendarHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from datetime import datetime
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py
index f283649faac..7c40c8fae5d 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py
@@ -31,9 +31,9 @@ from uuid import UUID
from cassandra.util import Date, OrderedMapSerializedKey, SortedSet, Time
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
index ea436670766..8814df2c02d 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/facebook_ads_to_gcs.py
@@ -26,9 +26,9 @@ from enum import Enum
from typing import TYPE_CHECKING, Any
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.providers.facebook.ads.hooks.ads import FacebookAdsReportingHook
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from facebook_business.adobjects.adsinsights import AdsInsights
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
index 08a119a4d3a..f4c47dec942 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -38,12 +38,12 @@ from google.cloud.bigquery.table import
EncryptionConfiguration, Table, TableRef
from airflow.configuration import conf
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook,
BigQueryJob
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
from airflow.providers.google.cloud.triggers.bigquery import
BigQueryInsertJobTrigger
from airflow.providers.google.common.hooks.base_google import
PROVIDE_PROJECT_ID
+from airflow.providers.google.version_compat import BaseOperator
from airflow.utils.helpers import merge_dicts
if TYPE_CHECKING:
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
index 54a8269709a..296d216b881 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_gcs.py
@@ -24,8 +24,8 @@ from collections.abc import Sequence
from typing import TYPE_CHECKING
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.version_compat import BaseOperator
WILDCARD = "*"
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_local.py
b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_local.py
index 70cdf0cdb9b..b407829cc52 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_local.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_local.py
@@ -20,9 +20,9 @@ from collections.abc import Sequence
from typing import TYPE_CHECKING
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.models.xcom import MAX_XCOM_SIZE
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py
b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py
index f529cef3613..7aebbe1b685 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/gcs_to_sftp.py
@@ -26,8 +26,8 @@ from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.version_compat import BaseOperator
from airflow.providers.sftp.hooks.sftp import SFTPHook
WILDCARD = "*"
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_gcs.py
index dc57fd712c5..827f7455910 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_gcs.py
@@ -19,9 +19,9 @@ from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.suite.hooks.drive import GoogleDriveHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -99,3 +99,7 @@ class GoogleDriveToGCSOperator(BaseOperator):
bucket_name=self.bucket_name, object_name=self.object_name
) as file:
gdrive_hook.download_file(file_id=file_metadata["id"],
file_handle=file)
+
+ def dry_run(self):
+ """Perform a dry run of the operator."""
+ return None
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_local.py
b/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_local.py
index 22dcc9f67cc..12d903ea52f 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_local.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/gdrive_to_local.py
@@ -19,8 +19,8 @@ from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
-from airflow.models import BaseOperator
from airflow.providers.google.suite.hooks.drive import GoogleDriveHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/http_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/http_to_gcs.py
index 71b2b7020af..b06ac8d22e5 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/http_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/http_to_gcs.py
@@ -22,8 +22,8 @@ from __future__ import annotations
from functools import cached_property
from typing import TYPE_CHECKING, Any
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.version_compat import BaseOperator
from airflow.providers.http.hooks.http import HttpHook
if TYPE_CHECKING:
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/local_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/local_to_gcs.py
index b1a143b242f..b6c183e6005 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/local_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/local_to_gcs.py
@@ -24,8 +24,8 @@ from collections.abc import Sequence
from glob import glob
from typing import TYPE_CHECKING
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py
index b83f09be70b..5ffe7f59f32 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py
@@ -21,8 +21,8 @@ import tempfile
from collections.abc import Sequence
from typing import TYPE_CHECKING
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.version_compat import BaseOperator
from airflow.providers.salesforce.hooks.salesforce import SalesforceHook
if TYPE_CHECKING:
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py
index 1653b84ec2b..9e53d16f943 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/sftp_to_gcs.py
@@ -26,8 +26,8 @@ from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.version_compat import BaseOperator
from airflow.providers.sftp.hooks.sftp import SFTPHook
if TYPE_CHECKING:
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/sheets_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/sheets_to_gcs.py
index 8c681bc5c9d..6f11b38531f 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/sheets_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/sheets_to_gcs.py
@@ -21,9 +21,9 @@ from collections.abc import Sequence
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.suite.hooks.sheets import GSheetsHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -130,5 +130,5 @@ class GoogleSheetsToGCSOperator(BaseOperator):
gcs_path_to_file = self._upload_data(gcs_hook, sheet_hook,
sheet_range, data)
destination_array.append(gcs_path_to_file)
- self.xcom_push(context, "destination_objects", destination_array)
+ context["ti"].xcom_push(key="destination_objects",
value=destination_array)
return destination_array
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py
index 6953c708767..e4ec2d730ef 100644
---
a/providers/google/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py
+++
b/providers/google/src/airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -30,8 +30,8 @@ from typing import TYPE_CHECKING, Any
import pyarrow as pa
import pyarrow.parquet as pq
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.providers.common.compat.openlineage.facet import OutputDataset
diff --git
a/providers/google/src/airflow/providers/google/firebase/operators/firestore.py
b/providers/google/src/airflow/providers/google/firebase/operators/firestore.py
index 055bcadb92a..08ad61c7616 100644
---
a/providers/google/src/airflow/providers/google/firebase/operators/firestore.py
+++
b/providers/google/src/airflow/providers/google/firebase/operators/firestore.py
@@ -20,9 +20,9 @@ from collections.abc import Sequence
from typing import TYPE_CHECKING
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.providers.google.common.hooks.base_google import
PROVIDE_PROJECT_ID
from airflow.providers.google.firebase.hooks.firestore import
CloudFirestoreHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git
a/providers/google/src/airflow/providers/google/leveldb/operators/leveldb.py
b/providers/google/src/airflow/providers/google/leveldb/operators/leveldb.py
index 2d544e89b45..77320d5e6bd 100644
--- a/providers/google/src/airflow/providers/google/leveldb/operators/leveldb.py
+++ b/providers/google/src/airflow/providers/google/leveldb/operators/leveldb.py
@@ -18,8 +18,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any
-from airflow.models import BaseOperator
from airflow.providers.google.leveldb.hooks.leveldb import LevelDBHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git
a/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py
b/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py
index 9c055889bd7..a783ddf0081 100644
---
a/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py
+++
b/providers/google/src/airflow/providers/google/marketing_platform/links/analytics_admin.py
@@ -18,13 +18,12 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator
+
if TYPE_CHECKING:
- from airflow.models import BaseOperator
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context
-from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
-
if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperatorLink
from airflow.sdk.execution_time.xcom import XCom
@@ -64,11 +63,9 @@ class GoogleAnalyticsPropertyLink(GoogleAnalyticsBaseLink):
@staticmethod
def persist(
context: Context,
- task_instance: BaseOperator,
property_id: str,
):
- task_instance.xcom_push(
- context,
+ context["task_instance"].xcom_push(
key=GoogleAnalyticsPropertyLink.key,
value={"property_id": property_id},
)
diff --git
a/providers/google/src/airflow/providers/google/marketing_platform/operators/analytics_admin.py
b/providers/google/src/airflow/providers/google/marketing_platform/operators/analytics_admin.py
index 8f662f9c92d..4465a6084aa 100644
---
a/providers/google/src/airflow/providers/google/marketing_platform/operators/analytics_admin.py
+++
b/providers/google/src/airflow/providers/google/marketing_platform/operators/analytics_admin.py
@@ -194,7 +194,6 @@ class
GoogleAnalyticsAdminCreatePropertyOperator(GoogleCloudBaseOperator):
self.log.info("The Google Analytics property %s was created
successfully.", prop.name)
GoogleAnalyticsPropertyLink.persist(
context=context,
- task_instance=self,
property_id=prop.name.lstrip("properties/"),
)
diff --git
a/providers/google/src/airflow/providers/google/marketing_platform/operators/campaign_manager.py
b/providers/google/src/airflow/providers/google/marketing_platform/operators/campaign_manager.py
index 4673f3c53a2..e68a2862848 100644
---
a/providers/google/src/airflow/providers/google/marketing_platform/operators/campaign_manager.py
+++
b/providers/google/src/airflow/providers/google/marketing_platform/operators/campaign_manager.py
@@ -28,9 +28,9 @@ from typing import TYPE_CHECKING, Any
from googleapiclient import http
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.marketing_platform.hooks.campaign_manager import
GoogleCampaignManagerHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -237,7 +237,7 @@ class
GoogleCampaignManagerDownloadReportOperator(BaseOperator):
mime_type="text/csv",
)
- self.xcom_push(context, key="report_name", value=report_name)
+ context["task_instance"].xcom_push(key="report_name",
value=report_name)
class GoogleCampaignManagerInsertReportOperator(BaseOperator):
@@ -308,7 +308,7 @@ class
GoogleCampaignManagerInsertReportOperator(BaseOperator):
self.log.info("Inserting Campaign Manager report.")
response = hook.insert_report(profile_id=self.profile_id,
report=self.report)
report_id = response.get("id")
- self.xcom_push(context, key="report_id", value=report_id)
+ context["task_instance"].xcom_push(key="report_id", value=report_id)
self.log.info("Report successfully inserted. Report id: %s", report_id)
return response
@@ -381,7 +381,7 @@ class GoogleCampaignManagerRunReportOperator(BaseOperator):
synchronous=self.synchronous,
)
file_id = response.get("id")
- self.xcom_push(context, key="file_id", value=file_id)
+ context["task_instance"].xcom_push(key="file_id", value=file_id)
self.log.info("Report file id: %s", file_id)
return response
diff --git
a/providers/google/src/airflow/providers/google/marketing_platform/operators/display_video.py
b/providers/google/src/airflow/providers/google/marketing_platform/operators/display_video.py
index a45ca8ae2d2..0c88db9f94e 100644
---
a/providers/google/src/airflow/providers/google/marketing_platform/operators/display_video.py
+++
b/providers/google/src/airflow/providers/google/marketing_platform/operators/display_video.py
@@ -29,9 +29,9 @@ from typing import TYPE_CHECKING, Any
from urllib.parse import urlsplit
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.marketing_platform.hooks.display_video import
GoogleDisplayVideo360Hook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -99,7 +99,7 @@ class GoogleDisplayVideo360CreateQueryOperator(BaseOperator):
self.log.info("Creating Display & Video 360 query.")
response = hook.create_query(query=self.body)
query_id = response["queryId"]
- self.xcom_push(context, key="query_id", value=query_id)
+ context["task_instance"].xcom_push(key="query_id", value=query_id)
self.log.info("Created query with ID: %s", query_id)
return response
@@ -295,7 +295,7 @@ class
GoogleDisplayVideo360DownloadReportV2Operator(BaseOperator):
self.bucket_name,
report_name,
)
- self.xcom_push(context, key="report_name", value=report_name)
+ context["task_instance"].xcom_push(key="report_name",
value=report_name)
class GoogleDisplayVideo360RunQueryOperator(BaseOperator):
@@ -360,8 +360,8 @@ class GoogleDisplayVideo360RunQueryOperator(BaseOperator):
self.parameters,
)
response = hook.run_query(query_id=self.query_id,
params=self.parameters)
- self.xcom_push(context, key="query_id",
value=response["key"]["queryId"])
- self.xcom_push(context, key="report_id",
value=response["key"]["reportId"])
+ context["task_instance"].xcom_push(key="query_id",
value=response["key"]["queryId"])
+ context["task_instance"].xcom_push(key="report_id",
value=response["key"]["reportId"])
return response
@@ -564,7 +564,7 @@ class
GoogleDisplayVideo360CreateSDFDownloadTaskOperator(BaseOperator):
operation =
hook.create_sdf_download_operation(body_request=self.body_request)
name = operation["name"]
- self.xcom_push(context, key="name", value=name)
+ context["task_instance"].xcom_push(key="name", value=name)
self.log.info("Created SDF operation with name: %s", name)
return operation
diff --git
a/providers/google/src/airflow/providers/google/marketing_platform/operators/search_ads.py
b/providers/google/src/airflow/providers/google/marketing_platform/operators/search_ads.py
index 1c7e120aa37..7b22197596c 100644
---
a/providers/google/src/airflow/providers/google/marketing_platform/operators/search_ads.py
+++
b/providers/google/src/airflow/providers/google/marketing_platform/operators/search_ads.py
@@ -23,8 +23,8 @@ from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any
-from airflow.models import BaseOperator
from airflow.providers.google.marketing_platform.hooks.search_ads import
GoogleSearchAdsReportingHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git
a/providers/google/src/airflow/providers/google/suite/operators/sheets.py
b/providers/google/src/airflow/providers/google/suite/operators/sheets.py
index f6eaaf9d2ba..2cf942218c6 100644
--- a/providers/google/src/airflow/providers/google/suite/operators/sheets.py
+++ b/providers/google/src/airflow/providers/google/suite/operators/sheets.py
@@ -19,8 +19,8 @@ from __future__ import annotations
from collections.abc import Sequence
from typing import Any
-from airflow.models import BaseOperator
from airflow.providers.google.suite.hooks.sheets import GSheetsHook
+from airflow.providers.google.version_compat import BaseOperator
class GoogleSheetsCreateSpreadsheetOperator(BaseOperator):
@@ -68,6 +68,6 @@ class GoogleSheetsCreateSpreadsheetOperator(BaseOperator):
impersonation_chain=self.impersonation_chain,
)
spreadsheet = hook.create_spreadsheet(spreadsheet=self.spreadsheet)
- self.xcom_push(context, "spreadsheet_id", spreadsheet["spreadsheetId"])
- self.xcom_push(context, "spreadsheet_url",
spreadsheet["spreadsheetUrl"])
+ context["task_instance"].xcom_push(key="spreadsheet_id",
value=spreadsheet["spreadsheetId"])
+ context["task_instance"].xcom_push(key="spreadsheet_url",
value=spreadsheet["spreadsheetUrl"])
return spreadsheet
diff --git
a/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_gdrive.py
b/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_gdrive.py
index 2a8728e30ea..06854bca89d 100644
---
a/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_gdrive.py
+++
b/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_gdrive.py
@@ -24,9 +24,9 @@ from collections.abc import Sequence
from typing import TYPE_CHECKING
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.suite.hooks.drive import GoogleDriveHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git
a/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_sheets.py
b/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_sheets.py
index 95b9120df91..77f3fc03e42 100644
---
a/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_sheets.py
+++
b/providers/google/src/airflow/providers/google/suite/transfers/gcs_to_sheets.py
@@ -21,9 +21,9 @@ from collections.abc import Sequence
from tempfile import NamedTemporaryFile
from typing import Any
-from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.suite.hooks.sheets import GSheetsHook
+from airflow.providers.google.version_compat import BaseOperator
class GCSToGoogleSheetsOperator(BaseOperator):
diff --git
a/providers/google/src/airflow/providers/google/suite/transfers/local_to_drive.py
b/providers/google/src/airflow/providers/google/suite/transfers/local_to_drive.py
index 9a9e8317628..aa712d773da 100644
---
a/providers/google/src/airflow/providers/google/suite/transfers/local_to_drive.py
+++
b/providers/google/src/airflow/providers/google/suite/transfers/local_to_drive.py
@@ -24,8 +24,8 @@ from pathlib import Path
from typing import TYPE_CHECKING
from airflow.exceptions import AirflowFailException
-from airflow.models import BaseOperator
from airflow.providers.google.suite.hooks.drive import GoogleDriveHook
+from airflow.providers.google.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git a/providers/google/src/airflow/providers/google/version_compat.py
b/providers/google/src/airflow/providers/google/version_compat.py
index 48d122b6696..9b604e6b950 100644
--- a/providers/google/src/airflow/providers/google/version_compat.py
+++ b/providers/google/src/airflow/providers/google/version_compat.py
@@ -33,3 +33,31 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
+AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
+
+# Version-compatible imports
+# BaseOperator: Use 3.1+ due to xcom_push method missing in SDK BaseOperator 3.0.x
+# This is needed for DecoratedOperator compatibility
+if AIRFLOW_V_3_1_PLUS:
+ from airflow.sdk import BaseOperator
+else:
+ from airflow.models import BaseOperator
+
+# Other SDK components: Available since 3.0+
+if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk import (
+ BaseOperatorLink,
+ BaseSensorOperator,
+ )
+else:
+ from airflow.models import BaseOperatorLink # type: ignore[no-redef]
+ from airflow.sensors.base import BaseSensorOperator # type: ignore[no-redef]
+
+# Explicitly export these imports to protect them from being removed by linters
+__all__ = [
+ "AIRFLOW_V_3_0_PLUS",
+ "AIRFLOW_V_3_1_PLUS",
+ "BaseOperator",
+ "BaseSensorOperator",
+ "BaseOperatorLink",
+]
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_cloud_build.py
b/providers/google/tests/unit/google/cloud/operators/test_cloud_build.py
index b7ca33cae72..66602204431 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_cloud_build.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_cloud_build.py
@@ -426,7 +426,7 @@ def
test_async_create_build_fires_correct_trigger_should_execute_successfully(
)
with pytest.raises(TaskDeferred) as exc:
- ti.task.execute({"ti": ti})
+ ti.task.execute({"ti": ti, "task_instance": ti})
assert isinstance(exc.value.trigger, CloudBuildCreateBuildTrigger), (
"Trigger is not a CloudBuildCreateBuildTrigger"
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_datacatalog.py
b/providers/google/tests/unit/google/cloud/operators/test_datacatalog.py
index 0026dd1f501..da5b420edb3 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_datacatalog.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_datacatalog.py
@@ -121,6 +121,7 @@ TEST_TAG_DICT: dict = {
}
TEST_TAG_TEMPLATE: TagTemplate = TagTemplate(name=TEST_TAG_TEMPLATE_PATH)
TEST_TAG_TEMPLATE_DICT: dict = {
+ "dataplex_transfer_status": 0,
"display_name": "",
"fields": {},
"is_publicly_readable": False,
@@ -145,8 +146,7 @@ class TestCloudDataCatalogCreateEntryOperator:
"airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
**{"return_value.create_entry.return_value": TEST_ENTRY},
)
-
@mock.patch(BASE_PATH.format("CloudDataCatalogCreateEntryOperator.xcom_push"))
- def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None:
+ def test_assert_valid_hook_call(self, mock_hook) -> None:
with pytest.warns(AirflowProviderDeprecationWarning):
task = CloudDataCatalogCreateEntryOperator(
task_id="task_id",
@@ -161,8 +161,9 @@ class TestCloudDataCatalogCreateEntryOperator:
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
- context = mock.MagicMock()
- result = task.execute(context=context)
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+ result = task.execute(context=mock_context) # type: ignore[arg-type]
mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
@@ -177,8 +178,7 @@ class TestCloudDataCatalogCreateEntryOperator:
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
- mock_xcom.assert_called_with(
- context,
+ mock_ti.xcom_push.assert_any_call(
key="entry_id",
value=TEST_ENTRY_ID,
)
@@ -186,8 +186,7 @@ class TestCloudDataCatalogCreateEntryOperator:
assert result == TEST_ENTRY_DICT
@mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook")
-
@mock.patch(BASE_PATH.format("CloudDataCatalogCreateEntryOperator.xcom_push"))
- def test_assert_valid_hook_call_when_exists(self, mock_xcom, mock_hook) ->
None:
+ def test_assert_valid_hook_call_when_exists(self, mock_hook) -> None:
mock_hook.return_value.create_entry.side_effect =
AlreadyExists(message="message")
mock_hook.return_value.get_entry.return_value = TEST_ENTRY
with pytest.warns(AirflowProviderDeprecationWarning):
@@ -204,8 +203,9 @@ class TestCloudDataCatalogCreateEntryOperator:
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
- context = mock.MagicMock()
- result = task.execute(context=context)
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+ result = task.execute(context=mock_context) # type: ignore[arg-type]
mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
@@ -229,8 +229,7 @@ class TestCloudDataCatalogCreateEntryOperator:
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
- mock_xcom.assert_called_with(
- context,
+ mock_ti.xcom_push.assert_any_call(
key="entry_id",
value=TEST_ENTRY_ID,
)
@@ -242,8 +241,7 @@ class TestCloudDataCatalogCreateEntryGroupOperator:
"airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
**{"return_value.create_entry_group.return_value": TEST_ENTRY_GROUP},
)
-
@mock.patch(BASE_PATH.format("CloudDataCatalogCreateEntryGroupOperator.xcom_push"))
- def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None:
+ def test_assert_valid_hook_call(self, mock_hook) -> None:
with pytest.warns(AirflowProviderDeprecationWarning):
task = CloudDataCatalogCreateEntryGroupOperator(
task_id="task_id",
@@ -257,8 +255,9 @@ class TestCloudDataCatalogCreateEntryGroupOperator:
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
- context = mock.MagicMock()
- result = task.execute(context=context)
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+ result = task.execute(context=mock_context) # type: ignore[arg-type]
mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
@@ -272,8 +271,7 @@ class TestCloudDataCatalogCreateEntryGroupOperator:
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
- mock_xcom.assert_called_with(
- context,
+ mock_ti.xcom_push.assert_any_call(
key="entry_group_id",
value=TEST_ENTRY_GROUP_ID,
)
@@ -285,8 +283,7 @@ class TestCloudDataCatalogCreateTagOperator:
"airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
**{"return_value.create_tag.return_value": TEST_TAG},
)
-
@mock.patch(BASE_PATH.format("CloudDataCatalogCreateTagOperator.xcom_push"))
- def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None:
+ def test_assert_valid_hook_call(self, mock_hook) -> None:
with pytest.warns(AirflowProviderDeprecationWarning):
task = CloudDataCatalogCreateTagOperator(
task_id="task_id",
@@ -302,8 +299,9 @@ class TestCloudDataCatalogCreateTagOperator:
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
- context = mock.MagicMock()
- result = task.execute(context=context)
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+ result = task.execute(context=mock_context) # type: ignore[arg-type]
mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
@@ -319,8 +317,7 @@ class TestCloudDataCatalogCreateTagOperator:
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
- mock_xcom.assert_called_with(
- context,
+ mock_ti.xcom_push.assert_any_call(
key="tag_id",
value=TEST_TAG_ID,
)
@@ -332,8 +329,7 @@ class TestCloudDataCatalogCreateTagTemplateOperator:
"airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
**{"return_value.create_tag_template.return_value": TEST_TAG_TEMPLATE},
)
-
@mock.patch(BASE_PATH.format("CloudDataCatalogCreateTagTemplateOperator.xcom_push"))
- def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None:
+ def test_assert_valid_hook_call(self, mock_hook) -> None:
with pytest.warns(AirflowProviderDeprecationWarning):
task = CloudDataCatalogCreateTagTemplateOperator(
task_id="task_id",
@@ -347,8 +343,9 @@ class TestCloudDataCatalogCreateTagTemplateOperator:
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
- context = mock.MagicMock()
- result = task.execute(context=context)
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+ result = task.execute(context=mock_context) # type: ignore[arg-type]
mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
@@ -362,12 +359,11 @@ class TestCloudDataCatalogCreateTagTemplateOperator:
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
- mock_xcom.assert_called_with(
- context,
+ mock_ti.xcom_push.assert_any_call(
key="tag_template_id",
value=TEST_TAG_TEMPLATE_ID,
)
- assert result == {**result, **TEST_TAG_TEMPLATE_DICT}
+ assert result == TEST_TAG_TEMPLATE_DICT
class TestCloudDataCatalogCreateTagTemplateFieldOperator:
@@ -375,8 +371,7 @@ class TestCloudDataCatalogCreateTagTemplateFieldOperator:
"airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook",
**{"return_value.create_tag_template_field.return_value":
TEST_TAG_TEMPLATE_FIELD}, # type: ignore
)
-
@mock.patch(BASE_PATH.format("CloudDataCatalogCreateTagTemplateFieldOperator.xcom_push"))
- def test_assert_valid_hook_call(self, mock_xcom, mock_hook) -> None:
+ def test_assert_valid_hook_call(self, mock_hook) -> None:
with pytest.warns(AirflowProviderDeprecationWarning):
task = CloudDataCatalogCreateTagTemplateFieldOperator(
task_id="task_id",
@@ -391,8 +386,9 @@ class TestCloudDataCatalogCreateTagTemplateFieldOperator:
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
- context = mock.MagicMock()
- result = task.execute(context=context)
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+ result = task.execute(context=mock_context) # type: ignore[arg-type]
mock_hook.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
@@ -407,12 +403,11 @@ class TestCloudDataCatalogCreateTagTemplateFieldOperator:
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
- mock_xcom.assert_called_with(
- context,
+ mock_ti.xcom_push.assert_any_call(
key="tag_template_field_id",
value=TEST_TAG_TEMPLATE_FIELD_ID,
)
- assert result == {**result, **TEST_TAG_TEMPLATE_FIELD_DICT}
+ assert result == TEST_TAG_TEMPLATE_FIELD_DICT
class TestCloudDataCatalogDeleteEntryOperator:
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_dataflow.py
b/providers/google/tests/unit/google/cloud/operators/test_dataflow.py
index a9bc77b02b6..c7e7622b65f 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataflow.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataflow.py
@@ -161,11 +161,11 @@ class TestDataflowTemplatedJobStartOperator:
cancel_timeout=CANCEL_TIMEOUT,
)
- @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push")
@mock.patch(f"{DATAFLOW_PATH}.DataflowHook")
- def test_execute(self, hook_mock, mock_xcom_push, sync_operator):
+ def test_execute(self, hook_mock, sync_operator):
start_template_hook = hook_mock.return_value.start_template_dataflow
- sync_operator.execute(None)
+ mock_context = {"task_instance": mock.MagicMock()}
+ sync_operator.execute(mock_context)
assert hook_mock.called
expected_options = {
"project": "test",
@@ -231,9 +231,8 @@ class TestDataflowTemplatedJobStartOperator:
DataflowTemplatedJobStartOperator(**init_kwargs)
@pytest.mark.db_test
- @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push")
@mock.patch(f"{DATAFLOW_PATH}.DataflowHook.start_template_dataflow")
- def test_start_with_custom_region(self, dataflow_mock, mock_xcom_push):
+ def test_start_with_custom_region(self, dataflow_mock):
init_kwargs = {
"task_id": TASK_ID,
"template": TEMPLATE,
@@ -245,16 +244,16 @@ class TestDataflowTemplatedJobStartOperator:
"cancel_timeout": CANCEL_TIMEOUT,
}
operator = DataflowTemplatedJobStartOperator(**init_kwargs)
- operator.execute(None)
+ mock_context = {"task_instance": mock.MagicMock()}
+ operator.execute(mock_context)
assert dataflow_mock.called
_, kwargs = dataflow_mock.call_args_list[0]
assert kwargs["variables"]["region"] == TEST_REGION
assert kwargs["location"] == DEFAULT_DATAFLOW_LOCATION
@pytest.mark.db_test
- @mock.patch(f"{DATAFLOW_PATH}.DataflowTemplatedJobStartOperator.xcom_push")
@mock.patch(f"{DATAFLOW_PATH}.DataflowHook.start_template_dataflow")
- def test_start_with_location(self, dataflow_mock, mock_xcom_push):
+ def test_start_with_location(self, dataflow_mock):
init_kwargs = {
"task_id": TASK_ID,
"template": TEMPLATE,
@@ -264,7 +263,8 @@ class TestDataflowTemplatedJobStartOperator:
"cancel_timeout": CANCEL_TIMEOUT,
}
operator = DataflowTemplatedJobStartOperator(**init_kwargs)
- operator.execute(None)
+ mock_context = {"task_instance": mock.MagicMock()}
+ operator.execute(mock_context)
assert dataflow_mock.called
_, kwargs = dataflow_mock.call_args_list[0]
assert not kwargs["variables"]
@@ -409,19 +409,18 @@ class TestDataflowStartYamlJobOperator:
)
mock_defer_method.assert_called_once()
- @mock.patch(f"{DATAFLOW_PATH}.DataflowStartYamlJobOperator.xcom_push")
@mock.patch(f"{DATAFLOW_PATH}.DataflowHook")
- def test_execute_complete_success(self, mock_hook, mock_xcom_push,
deferrable_operator):
+ def test_execute_complete_success(self, mock_hook, deferrable_operator):
expected_result = {"id": JOB_ID}
+ mock_context = {"task_instance": mock.MagicMock()}
actual_result = deferrable_operator.execute_complete(
- context=None,
+ context=mock_context,
event={
"status": "success",
"message": "Batch job completed.",
"job": expected_result,
},
)
- mock_xcom_push.assert_called_with(None, key="job_id", value=JOB_ID)
assert actual_result == expected_result
def test_execute_complete_error_status_raises_exception(self,
deferrable_operator):
@@ -449,7 +448,8 @@ class TestDataflowStopJobOperator:
Test DataflowHook is created and the right args are passed to
cancel_job.
"""
cancel_job_hook = dataflow_mock.return_value.cancel_job
- self.dataflow.execute(None)
+ mock_context = {"task_instance": mock.MagicMock()}
+ self.dataflow.execute(mock_context)
assert dataflow_mock.called
cancel_job_hook.assert_called_once_with(
job_name=None,
@@ -473,7 +473,8 @@ class TestDataflowStopJobOperator:
"""
is_job_running_hook =
dataflow_mock.return_value.is_job_dataflow_running
cancel_job_hook = dataflow_mock.return_value.cancel_job
- self.dataflow.execute(None)
+ mock_context = {"task_instance": mock.MagicMock()}
+ self.dataflow.execute(mock_context)
assert dataflow_mock.called
is_job_running_hook.assert_called_once_with(
name=JOB_NAME,
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index 00923cb590a..91c2c0a3512 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -69,11 +69,11 @@ from airflow.providers.google.cloud.triggers.dataproc
import (
DataprocSubmitTrigger,
)
from airflow.providers.google.common.consts import
GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils.timezone import datetime
from tests_common.test_utils.db import clear_db_runs, clear_db_xcom
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
if AIRFLOW_V_3_0_PLUS:
from airflow.sdk.execution_time.comms import XComResult
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_functions.py
b/providers/google/tests/unit/google/cloud/operators/test_functions.py
index b0752365893..47b3e4ebde6 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_functions.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_functions.py
@@ -693,9 +693,8 @@ class TestGcfFunctionDelete:
class TestGcfFunctionInvokeOperator:
-
@mock.patch("airflow.providers.google.cloud.operators.functions.GoogleCloudBaseOperator.xcom_push")
@mock.patch("airflow.providers.google.cloud.operators.functions.CloudFunctionsHook")
- def test_execute(self, mock_gcf_hook, mock_xcom):
+ def test_execute(self, mock_gcf_hook):
exec_id = "exec_id"
mock_gcf_hook.return_value.call_function.return_value =
{"executionId": exec_id}
@@ -715,8 +714,9 @@ class TestGcfFunctionInvokeOperator:
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
)
- context = mock.MagicMock()
- op.execute(context=context)
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+ op.execute(mock_context)
mock_gcf_hook.assert_called_once_with(
api_version=api_version,
@@ -728,8 +728,7 @@ class TestGcfFunctionInvokeOperator:
function_id=function_id, input_data=payload,
location=GCP_LOCATION, project_id=GCP_PROJECT_ID
)
- mock_xcom.assert_called_with(
- context=context,
+ mock_ti.xcom_push.assert_any_call(
key="execution_id",
value=exec_id,
)
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_translate.py
b/providers/google/tests/unit/google/cloud/operators/test_translate.py
index 957bd3ddd1b..4e20d8ee2ac 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_translate.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_translate.py
@@ -214,9 +214,8 @@ class TestTranslateTextBatchOperator:
class TestTranslateDatasetCreate:
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslationNativeDatasetLink.persist")
-
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator.xcom_push")
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook")
- def test_minimal_green_path(self, mock_hook, mock_xcom_push,
mock_link_persist):
+ def test_minimal_green_path(self, mock_hook, mock_link_persist):
DS_CREATION_RESULT_SAMPLE = {
"display_name": "",
"example_count": 0,
@@ -249,8 +248,9 @@ class TestTranslateDatasetCreate:
timeout=TIMEOUT_VALUE,
retry=None,
)
- context = mock.MagicMock()
- result = op.execute(context=context)
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+ result = op.execute(context=mock_context)
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -263,9 +263,9 @@ class TestTranslateDatasetCreate:
retry=None,
metadata=(),
)
- mock_xcom_push.assert_called_once_with(context, key="dataset_id", value=DATASET_ID)
+ mock_ti.xcom_push.assert_any_call(key="dataset_id", value=DATASET_ID)
mock_link_persist.assert_called_once_with(
- context=context,
+ context=mock_context,
dataset_id=DATASET_ID,
location=LOCATION,
project_id=PROJECT_ID,
@@ -402,9 +402,8 @@ class TestTranslateDeleteData:
class TestTranslateModelCreate:
@mock.patch("airflow.providers.google.cloud.links.translate.TranslationModelLink.persist")
-
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator.xcom_push")
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook")
- def test_minimal_green_path(self, mock_hook, mock_xcom_push, mock_link_persist):
+ def test_minimal_green_path(self, mock_hook, mock_link_persist):
MODEL_DISPLAY_NAME = "model_display_name_01"
MODEL_CREATION_RESULT_SAMPLE = {
"display_name": MODEL_DISPLAY_NAME,
@@ -435,8 +434,9 @@ class TestTranslateModelCreate:
timeout=TIMEOUT_VALUE,
retry=None,
)
- context = mock.MagicMock()
- result = op.execute(context=context)
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+ result = op.execute(context=mock_context)
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -450,9 +450,9 @@ class TestTranslateModelCreate:
retry=None,
metadata=(),
)
- mock_xcom_push.assert_called_once_with(context, key="model_id", value=MODEL_ID)
+ mock_ti.xcom_push.assert_any_call(key="model_id", value=MODEL_ID)
mock_link_persist.assert_called_once_with(
- context=context,
+ context=mock_context,
model_id=MODEL_ID,
project_id=PROJECT_ID,
dataset_id=DATASET_ID,
@@ -711,11 +711,8 @@ class TestTranslateDocumentOperator:
class TestTranslateGlossaryCreate:
- @mock.patch(
-
"airflow.providers.google.cloud.operators.translate.TranslateCreateGlossaryOperator.xcom_push"
- )
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook")
- def test_minimal_green_path(self, mock_hook, mock_xcom_push):
+ def test_minimal_green_path(self, mock_hook):
GLOSSARY_CREATION_RESULT = {
"name": f"projects/{PROJECT_ID}/locations/{LOCATION}/glossaries/{GLOSSARY_ID}",
"display_name": f"{GLOSSARY_ID}",
@@ -746,8 +743,9 @@ class TestTranslateGlossaryCreate:
timeout=TIMEOUT_VALUE,
retry=None,
)
- context = mock.MagicMock()
- result = op.execute(context=context)
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+ result = op.execute(context=mock_context)
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
@@ -764,7 +762,7 @@ class TestTranslateGlossaryCreate:
retry=None,
metadata=(),
)
- mock_xcom_push.assert_called_once_with(context, key="glossary_id", value=GLOSSARY_ID)
+ mock_ti.xcom_push.assert_any_call(key="glossary_id", value=GLOSSARY_ID)
assert result == GLOSSARY_CREATION_RESULT
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
b/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
index 67d11428cb7..460e8b24847 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py
@@ -439,14 +439,11 @@ class
TestVertexAICreateCustomContainerTrainingJobOperator:
)
@mock.patch(VERTEX_AI_LINKS_PATH.format("VertexAIModelLink.persist"))
-
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.xcom_push"))
@mock.patch(
VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.hook.extract_model_id")
)
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.hook"))
- def test_execute_complete_success(
- self, mock_hook, mock_hook_extract_model_id, mock_xcom_push, mock_link_persist
- ):
+ def test_execute_complete_success(self, mock_hook, mock_hook_extract_model_id, mock_link_persist):
task = CreateCustomContainerTrainingJobOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -471,16 +468,20 @@ class
TestVertexAICreateCustomContainerTrainingJobOperator:
)
expected_result = TEST_TRAINING_PIPELINE_DATA["model_to_upload"]
mock_hook_extract_model_id.return_value = "test-model"
+
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+
actual_result = task.execute_complete(
- context=None,
+ context=mock_context,
event={
"status": "success",
"message": "",
"job": TEST_TRAINING_PIPELINE_DATA,
},
)
- mock_xcom_push.assert_called_with(None, key="model_id", value="test-model")
- mock_link_persist.assert_called_once_with(context=None, model_id="test-model")
+ mock_ti.xcom_push.assert_any_call(key="model_id", value="test-model")
+ mock_link_persist.assert_called_once_with(context=mock_context, model_id="test-model")
assert actual_result == expected_result
def test_execute_complete_error_status_raises_exception(self):
@@ -510,7 +511,6 @@ class TestVertexAICreateCustomContainerTrainingJobOperator:
task.execute_complete(context=None, event={"status": "error",
"message": "test message"})
@mock.patch(VERTEX_AI_LINKS_PATH.format("VertexAIModelLink.persist"))
-
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.xcom_push"))
@mock.patch(
VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.hook.extract_model_id")
)
@@ -519,7 +519,6 @@ class TestVertexAICreateCustomContainerTrainingJobOperator:
self,
mock_hook,
hook_extract_model_id,
- mock_xcom_push,
mock_link_persist,
):
task = CreateCustomContainerTrainingJobOperator(
@@ -544,11 +543,16 @@ class
TestVertexAICreateCustomContainerTrainingJobOperator:
)
expected_result = None
hook_extract_model_id.return_value = None
+
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+
actual_result = task.execute_complete(
- context=None,
+ context=mock_context,
event={"status": "success", "message": "", "job":
TEST_TRAINING_PIPELINE_DATA_NO_MODEL},
)
- mock_xcom_push.assert_called_once()
+ # When no model is produced, xcom_push should still be called but with None value
+ mock_ti.xcom_push.assert_called_once()
mock_link_persist.assert_not_called()
assert actual_result == expected_result
@@ -765,7 +769,6 @@ class
TestVertexAICreateCustomPythonPackageTrainingJobOperator:
)
@mock.patch(VERTEX_AI_LINKS_PATH.format("VertexAIModelLink.persist"))
-
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomPythonPackageTrainingJobOperator.xcom_push"))
@mock.patch(
VERTEX_AI_PATH.format("custom_job.CreateCustomPythonPackageTrainingJobOperator.hook.extract_model_id")
)
@@ -774,7 +777,6 @@ class
TestVertexAICreateCustomPythonPackageTrainingJobOperator:
self,
mock_hook,
hook_extract_model_id,
- mock_xcom_push,
mock_link_persist,
):
task = CreateCustomPythonPackageTrainingJobOperator(
@@ -802,16 +804,20 @@ class
TestVertexAICreateCustomPythonPackageTrainingJobOperator:
)
expected_result = TEST_TRAINING_PIPELINE_DATA["model_to_upload"]
hook_extract_model_id.return_value = "test-model"
+
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+
actual_result = task.execute_complete(
- context=None,
+ context=mock_context,
event={
"status": "success",
"message": "",
"job": TEST_TRAINING_PIPELINE_DATA,
},
)
- mock_xcom_push.assert_called_with(None, key="model_id", value="test-model")
- mock_link_persist.assert_called_once_with(context=None, model_id="test-model")
+ mock_ti.xcom_push.assert_any_call(key="model_id", value="test-model")
+ mock_link_persist.assert_called_once_with(context=mock_context, model_id="test-model")
assert actual_result == expected_result
def test_execute_complete_error_status_raises_exception(self):
@@ -842,7 +848,6 @@ class
TestVertexAICreateCustomPythonPackageTrainingJobOperator:
task.execute_complete(context=None, event={"status": "error",
"message": "test message"})
@mock.patch(VERTEX_AI_LINKS_PATH.format("VertexAIModelLink.persist"))
-
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomPythonPackageTrainingJobOperator.xcom_push"))
@mock.patch(
VERTEX_AI_PATH.format("custom_job.CreateCustomPythonPackageTrainingJobOperator.hook.extract_model_id")
)
@@ -851,7 +856,6 @@ class
TestVertexAICreateCustomPythonPackageTrainingJobOperator:
self,
mock_hook,
hook_extract_model_id,
- mock_xcom_push,
mock_link_persist,
):
task = CreateCustomPythonPackageTrainingJobOperator(
@@ -875,12 +879,15 @@ class
TestVertexAICreateCustomPythonPackageTrainingJobOperator:
project_id=GCP_PROJECT,
deferrable=True,
)
+
+ mock_ti = mock.MagicMock()
+
expected_result = None
actual_result = task.execute_complete(
- context=None,
+ context={"ti": mock_ti},
event={"status": "success", "message": "", "job":
TEST_TRAINING_PIPELINE_DATA_NO_MODEL},
)
- mock_xcom_push.assert_called_once()
+ mock_ti.xcom_push.assert_called_once()
mock_link_persist.assert_not_called()
assert actual_result == expected_result
@@ -1076,14 +1083,12 @@ class TestVertexAICreateCustomTrainingJobOperator:
)
@mock.patch(VERTEX_AI_LINKS_PATH.format("VertexAIModelLink.persist"))
-
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.xcom_push"))
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.hook.extract_model_id"))
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.hook"))
def test_execute_complete_success(
self,
mock_hook,
hook_extract_model_id,
- mock_xcom_push,
mock_link_persist,
):
task = CreateCustomTrainingJobOperator(
@@ -1104,16 +1109,18 @@ class TestVertexAICreateCustomTrainingJobOperator:
)
expected_result = TEST_TRAINING_PIPELINE_DATA["model_to_upload"]
hook_extract_model_id.return_value = "test-model"
+
+ mock_ti = mock.MagicMock()
actual_result = task.execute_complete(
- context=None,
+ context={"ti": mock_ti},
event={
"status": "success",
"message": "",
"job": TEST_TRAINING_PIPELINE_DATA,
},
)
- mock_xcom_push.assert_called_with(None, key="model_id", value="test-model")
- mock_link_persist.assert_called_once_with(context=None, model_id="test-model")
+ mock_ti.xcom_push.assert_called_with(key="model_id", value="test-model")
+ mock_link_persist.assert_called_once_with(context={"ti": mock_ti}, model_id="test-model")
assert actual_result == expected_result
def test_execute_complete_error_status_raises_exception(self):
@@ -1127,7 +1134,6 @@ class TestVertexAICreateCustomTrainingJobOperator:
args=PYTHON_PACKAGE_CMDARGS,
container_uri=CONTAINER_URI,
model_serving_container_image_uri=CONTAINER_URI,
- requirements=[],
replica_count=1,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
@@ -1137,14 +1143,12 @@ class TestVertexAICreateCustomTrainingJobOperator:
task.execute_complete(context=None, event={"status": "error",
"message": "test message"})
@mock.patch(VERTEX_AI_LINKS_PATH.format("VertexAIModelLink.persist"))
-
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.xcom_push"))
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.hook.extract_model_id"))
@mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.hook"))
def test_execute_complete_no_model_produced(
self,
mock_hook,
hook_extract_model_id,
- mock_xcom_push,
mock_link_persist,
):
task = CreateCustomTrainingJobOperator(
@@ -1164,10 +1168,14 @@ class TestVertexAICreateCustomTrainingJobOperator:
)
expected_result = None
hook_extract_model_id.return_value = None
+
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+
actual_result = task.execute_complete(
- context=None, event={"status": "success", "message": "", "job": {}}
+ context=mock_context, event={"status": "success", "message": "",
"job": {}}
)
- mock_xcom_push.assert_called_once()
+ mock_ti.xcom_push.assert_called_once()
mock_link_persist.assert_not_called()
assert actual_result == expected_result
@@ -2115,10 +2123,10 @@ class TestVertexAICreateBatchPredictionJobOperator:
assert exception_info.value.trigger.poll_interval == 10
assert exception_info.value.trigger.impersonation_chain ==
IMPERSONATION_CHAIN
-
@mock.patch(VERTEX_AI_PATH.format("batch_prediction_job.CreateBatchPredictionJobOperator.xcom_push"))
@mock.patch(VERTEX_AI_PATH.format("batch_prediction_job.BatchPredictionJobHook"))
- def test_execute_complete(self, mock_hook, mock_xcom_push):
- context = mock.MagicMock()
+ def test_execute_complete(self, mock_hook):
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
mock_job = {"name": TEST_JOB_DISPLAY_NAME}
event = {
"status": "success",
@@ -2139,14 +2147,13 @@ class TestVertexAICreateBatchPredictionJobOperator:
create_request_timeout=TEST_CREATE_REQUEST_TIMEOUT,
batch_size=TEST_BATCH_SIZE,
)
- execute_complete_result = op.execute_complete(context=context, event=event)
+ execute_complete_result = op.execute_complete(context=mock_context, event=event)
mock_hook.return_value.extract_batch_prediction_job_id.assert_called_once_with(mock_job)
- mock_xcom_push.assert_has_calls(
+ mock_ti.xcom_push.assert_has_calls(
[
- call(context, key="batch_prediction_job_id", value=TEST_BATCH_PREDICTION_JOB_ID),
+ call(key="batch_prediction_job_id", value=TEST_BATCH_PREDICTION_JOB_ID),
call(
- context,
key="training_conf",
value={
"training_conf_id": TEST_BATCH_PREDICTION_JOB_ID,
@@ -2969,9 +2976,8 @@ class TestVertexAIRunPipelineJobOperator:
task.execute(context={"ti": mock.MagicMock(), "task":
mock.MagicMock()})
assert isinstance(exc.value.trigger, RunPipelineJobTrigger), "Trigger
is not a RunPipelineJobTrigger"
-
@mock.patch(VERTEX_AI_PATH.format("pipeline_job.RunPipelineJobOperator.xcom_push"))
@mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJobHook"))
- def test_execute_complete_success(self, mock_hook, mock_xcom_push):
+ def test_execute_complete_success(self, mock_hook):
task = RunPipelineJobOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
@@ -2987,9 +2993,12 @@ class TestVertexAIRunPipelineJobOperator:
"name":
f"projects/{GCP_PROJECT}/locations/{GCP_LOCATION}/pipelineJobs/{TEST_PIPELINE_JOB_ID}",
}
mock_hook.return_value.exists.return_value = False
- mock_xcom_push.return_value = None
+
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
+
actual_result = task.execute_complete(
- context=None, event={"status": "success", "message": "", "job":
expected_pipeline_job}
+ context=mock_context, event={"status": "success", "message": "",
"job": expected_pipeline_job}
)
assert actual_result == expected_result
diff --git
a/providers/google/tests/unit/google/cloud/transfers/test_sheets_to_gcs.py
b/providers/google/tests/unit/google/cloud/transfers/test_sheets_to_gcs.py
index 797203ae818..5e4446a6c20 100644
--- a/providers/google/tests/unit/google/cloud/transfers/test_sheets_to_gcs.py
+++ b/providers/google/tests/unit/google/cloud/transfers/test_sheets_to_gcs.py
@@ -78,14 +78,15 @@ class TestGoogleSheetsToGCSOperator:
@mock.patch("airflow.providers.google.cloud.transfers.sheets_to_gcs.GCSHook")
@mock.patch("airflow.providers.google.cloud.transfers.sheets_to_gcs.GSheetsHook")
-
@mock.patch("airflow.providers.google.cloud.transfers.sheets_to_gcs.GoogleSheetsToGCSOperator.xcom_push")
@mock.patch(
"airflow.providers.google.cloud.transfers.sheets_to_gcs.GoogleSheetsToGCSOperator._upload_data"
)
- def test_execute(self, mock_upload_data, mock_xcom, mock_sheet_hook, mock_gcs_hook):
- context = {}
+ def test_execute(self, mock_upload_data, mock_sheet_hook, mock_gcs_hook):
+ mock_ti = mock.MagicMock()
+ mock_context = {"ti": mock_ti}
data = ["data1", "data2"]
mock_sheet_hook.return_value.get_sheet_titles.return_value = RANGES
+ mock_sheet_hook.return_value.get_values.side_effect = data
mock_upload_data.side_effect = [PATH, PATH]
op = GoogleSheetsToGCSOperator(
@@ -97,7 +98,7 @@ class TestGoogleSheetsToGCSOperator:
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
- op.execute(context)
+ op.execute(mock_context)
mock_sheet_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
@@ -115,9 +116,12 @@ class TestGoogleSheetsToGCSOperator:
calls = [mock.call(spreadsheet_id=SPREADSHEET_ID, range_=r) for r in
RANGES]
mock_sheet_hook.return_value.get_values.assert_has_calls(calls)
- calls = [mock.call(mock_gcs_hook, mock_sheet_hook, r, v) for r, v in zip(RANGES, data)]
- mock_upload_data.assert_called()
+ calls = [
+ mock.call(mock_gcs_hook.return_value, mock_sheet_hook.return_value, r, v)
+ for r, v in zip(RANGES, data)
+ ]
+ mock_upload_data.assert_has_calls(calls)
actual_call_count = mock_upload_data.call_count
assert len(RANGES) == actual_call_count
- mock_xcom.assert_called_once_with(context, "destination_objects", [PATH, PATH])
+ mock_ti.xcom_push.assert_called_once_with(key="destination_objects", value=[PATH, PATH])
diff --git a/providers/google/tests/unit/google/cloud/utils/airflow_util.py
b/providers/google/tests/unit/google/cloud/utils/airflow_util.py
index 891a00780b5..3e0b14cb0a5 100644
--- a/providers/google/tests/unit/google/cloud/utils/airflow_util.py
+++ b/providers/google/tests/unit/google/cloud/utils/airflow_util.py
@@ -28,7 +28,7 @@ from airflow.utils import timezone
from airflow.utils.types import DagRunType
if TYPE_CHECKING:
- from airflow.models.baseoperator import BaseOperator
+ from airflow.providers.google.version_compat import BaseOperator
def get_dag_run(dag_id: str = "test_dag_id", run_id: str = "test_dag_id") ->
DagRun:
diff --git
a/providers/google/tests/unit/google/marketing_platform/links/test_analytics_admin.py
b/providers/google/tests/unit/google/marketing_platform/links/test_analytics_admin.py
index bb015c9be24..9a1db76e79b 100644
---
a/providers/google/tests/unit/google/marketing_platform/links/test_analytics_admin.py
+++
b/providers/google/tests/unit/google/marketing_platform/links/test_analytics_admin.py
@@ -55,17 +55,15 @@ class TestGoogleAnalyticsPropertyLink:
assert url == ""
def test_persist(self):
- mock_context = mock.MagicMock()
mock_task_instance = mock.MagicMock()
+ mock_context = {"task_instance": mock_task_instance}
GoogleAnalyticsPropertyLink.persist(
context=mock_context,
- task_instance=mock_task_instance,
property_id=TEST_PROPERTY_ID,
)
mock_task_instance.xcom_push.assert_called_once_with(
- mock_context,
key=GoogleAnalyticsPropertyLink.key,
value={"property_id": TEST_PROPERTY_ID},
)
diff --git
a/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py
b/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py
index cb2af61d9d2..f2240b79aa5 100644
---
a/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py
+++
b/providers/google/tests/unit/google/marketing_platform/operators/test_campaign_manager.py
@@ -102,13 +102,8 @@ class TestGoogleCampaignManagerDownloadReportOperator:
)
@mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.GCSHook")
@mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.BaseOperator")
- @mock.patch(
- "airflow.providers.google.marketing_platform.operators."
-
"campaign_manager.GoogleCampaignManagerDownloadReportOperator.xcom_push"
- )
def test_execute(
self,
- xcom_mock,
mock_base_op,
gcs_hook_mock,
hook_mock,
@@ -120,6 +115,9 @@ class TestGoogleCampaignManagerDownloadReportOperator:
True,
)
tempfile_mock.NamedTemporaryFile.return_value.__enter__.return_value.name =
TEMP_FILE_NAME
+
+ mock_context = {"task_instance": mock.Mock()}
+
op = GoogleCampaignManagerDownloadReportOperator(
profile_id=PROFILE_ID,
report_id=REPORT_ID,
@@ -129,7 +127,7 @@ class TestGoogleCampaignManagerDownloadReportOperator:
api_version=API_VERSION,
task_id="test_task",
)
- op.execute(context=None)
+ op.execute(context=mock_context)
hook_mock.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
api_version=API_VERSION,
@@ -149,7 +147,9 @@ class TestGoogleCampaignManagerDownloadReportOperator:
filename=TEMP_FILE_NAME,
mime_type="text/csv",
)
- xcom_mock.assert_called_once_with(None, key="report_name", value=REPORT_NAME + ".gz")
+ mock_context["task_instance"].xcom_push.assert_called_once_with(
+ key="report_name", value=REPORT_NAME + ".gz"
+ )
@pytest.mark.parametrize(
"test_bucket_name",
@@ -214,13 +214,11 @@ class TestGoogleCampaignManagerInsertReportOperator:
"airflow.providers.google.marketing_platform.operators.campaign_manager.GoogleCampaignManagerHook"
)
@mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.BaseOperator")
- @mock.patch(
- "airflow.providers.google.marketing_platform.operators."
- "campaign_manager.GoogleCampaignManagerInsertReportOperator.xcom_push"
- )
- def test_execute(self, xcom_mock, mock_base_op, hook_mock):
+ def test_execute(self, mock_base_op, hook_mock):
report = {"report": "test"}
+ mock_context = {"task_instance": mock.Mock()}
+
hook_mock.return_value.insert_report.return_value = {"id": REPORT_ID}
op = GoogleCampaignManagerInsertReportOperator(
@@ -229,14 +227,14 @@ class TestGoogleCampaignManagerInsertReportOperator:
api_version=API_VERSION,
task_id="test_task",
)
- op.execute(context=None)
+ op.execute(context=mock_context)
hook_mock.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
api_version=API_VERSION,
impersonation_chain=None,
)
hook_mock.return_value.insert_report.assert_called_once_with(profile_id=PROFILE_ID,
report=report)
- xcom_mock.assert_called_once_with(None, key="report_id", value=REPORT_ID)
+ mock_context["task_instance"].xcom_push.assert_called_once_with(key="report_id", value=REPORT_ID)
def test_prepare_template(self):
report = {"key": "value"}
@@ -260,13 +258,11 @@ class TestGoogleCampaignManagerRunReportOperator:
"airflow.providers.google.marketing_platform.operators.campaign_manager.GoogleCampaignManagerHook"
)
@mock.patch("airflow.providers.google.marketing_platform.operators.campaign_manager.BaseOperator")
- @mock.patch(
- "airflow.providers.google.marketing_platform.operators."
- "campaign_manager.GoogleCampaignManagerRunReportOperator.xcom_push"
- )
- def test_execute(self, xcom_mock, mock_base_op, hook_mock):
+ def test_execute(self, mock_base_op, hook_mock):
synchronous = True
+ mock_context = {"task_instance": mock.Mock()}
+
hook_mock.return_value.run_report.return_value = {"id": FILE_ID}
op = GoogleCampaignManagerRunReportOperator(
@@ -276,7 +272,7 @@ class TestGoogleCampaignManagerRunReportOperator:
api_version=API_VERSION,
task_id="test_task",
)
- op.execute(context=None)
+ op.execute(context=mock_context)
hook_mock.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
api_version=API_VERSION,
@@ -285,7 +281,7 @@ class TestGoogleCampaignManagerRunReportOperator:
hook_mock.return_value.run_report.assert_called_once_with(
profile_id=PROFILE_ID, report_id=REPORT_ID, synchronous=synchronous
)
- xcom_mock.assert_called_once_with(None, key="file_id", value=FILE_ID)
+ mock_context["task_instance"].xcom_push.assert_called_once_with(key="file_id", value=FILE_ID)
class TestGoogleCampaignManagerBatchInsertConversionsOperator:
diff --git
a/providers/google/tests/unit/google/marketing_platform/operators/test_display_video.py
b/providers/google/tests/unit/google/marketing_platform/operators/test_display_video.py
index 8596743dd7d..db361c80f4a 100644
---
a/providers/google/tests/unit/google/marketing_platform/operators/test_display_video.py
+++
b/providers/google/tests/unit/google/marketing_platform/operators/test_display_video.py
@@ -83,10 +83,6 @@ class TestGoogleDisplayVideo360DownloadReportV2Operator:
@mock.patch("airflow.providers.google.marketing_platform.operators.display_video.shutil")
@mock.patch("airflow.providers.google.marketing_platform.operators.display_video.urllib.request")
@mock.patch("airflow.providers.google.marketing_platform.operators.display_video.tempfile")
- @mock.patch(
- "airflow.providers.google.marketing_platform.operators."
- "display_video.GoogleDisplayVideo360DownloadReportV2Operator.xcom_push"
- )
@mock.patch("airflow.providers.google.marketing_platform.operators.display_video.GCSHook")
@mock.patch(
"airflow.providers.google.marketing_platform.operators.display_video.GoogleDisplayVideo360Hook"
@@ -95,7 +91,6 @@ class TestGoogleDisplayVideo360DownloadReportV2Operator:
self,
mock_hook,
mock_gcs_hook,
- mock_xcom,
mock_temp,
mock_request,
mock_shutil,
@@ -109,6 +104,9 @@ class TestGoogleDisplayVideo360DownloadReportV2Operator:
"googleCloudStoragePath": file_path,
}
}
+ # Create mock context with task_instance
+ mock_context = {"task_instance": mock.Mock()}
+
op = GoogleDisplayVideo360DownloadReportV2Operator(
query_id=QUERY_ID,
report_id=REPORT_ID,
@@ -118,9 +116,9 @@ class TestGoogleDisplayVideo360DownloadReportV2Operator:
)
if should_except:
with pytest.raises(AirflowException):
- op.execute(context=None)
+ op.execute(context=mock_context)
return
- op.execute(context=None)
+ op.execute(context=mock_context)
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
api_version="v2",
@@ -139,7 +137,9 @@ class TestGoogleDisplayVideo360DownloadReportV2Operator:
mime_type="text/csv",
object_name=REPORT_NAME + ".gz",
)
- mock_xcom.assert_called_once_with(None, key="report_name", value=REPORT_NAME + ".gz")
+ mock_context["task_instance"].xcom_push.assert_called_once_with(
+ key="report_name", value=REPORT_NAME + ".gz"
+ )
@pytest.mark.parametrize(
"test_bucket_name",
@@ -199,15 +199,15 @@ class TestGoogleDisplayVideo360DownloadReportV2Operator:
class TestGoogleDisplayVideo360RunQueryOperator:
- @mock.patch(
- "airflow.providers.google.marketing_platform.operators."
- "display_video.GoogleDisplayVideo360RunQueryOperator.xcom_push"
- )
@mock.patch(
"airflow.providers.google.marketing_platform.operators.display_video.GoogleDisplayVideo360Hook"
)
- def test_execute(self, hook_mock, mock_xcom):
+ def test_execute(self, hook_mock):
parameters = {"param": "test"}
+
+ # Create mock context with task_instance
+ mock_context = {"task_instance": mock.Mock()}
+
hook_mock.return_value.run_query.return_value = {
"key": {
"queryId": QUERY_ID,
@@ -220,15 +220,15 @@ class TestGoogleDisplayVideo360RunQueryOperator:
api_version=API_VERSION,
task_id="test_task",
)
- op.execute(context=None)
+ op.execute(context=mock_context)
hook_mock.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
api_version=API_VERSION,
impersonation_chain=None,
)
- mock_xcom.assert_any_call(None, key="query_id", value=QUERY_ID)
- mock_xcom.assert_any_call(None, key="report_id", value=REPORT_ID)
+ mock_context["task_instance"].xcom_push.assert_any_call(key="query_id", value=QUERY_ID)
+ mock_context["task_instance"].xcom_push.assert_any_call(key="report_id", value=REPORT_ID)
hook_mock.return_value.run_query.assert_called_once_with(query_id=QUERY_ID, params=parameters)
@@ -388,20 +388,20 @@ class TestGoogleDisplayVideo360SDFtoGCSOperator:
class TestGoogleDisplayVideo360CreateSDFDownloadTaskOperator:
- @mock.patch(
- "airflow.providers.google.marketing_platform.operators."
-
"display_video.GoogleDisplayVideo360CreateSDFDownloadTaskOperator.xcom_push"
- )
@mock.patch(
"airflow.providers.google.marketing_platform.operators.display_video.GoogleDisplayVideo360Hook"
)
- def test_execute(self, mock_hook, xcom_mock):
+ def test_execute(self, mock_hook):
body_request = {
"version": "1",
"id": "id",
"filter": {"id": []},
}
test_name = "test_task"
+
+ # Create mock context with task_instance
+ mock_context = {"task_instance": mock.Mock()}
+
mock_hook.return_value.create_sdf_download_operation.return_value =
{"name": test_name}
op = GoogleDisplayVideo360CreateSDFDownloadTaskOperator(
@@ -411,7 +411,7 @@ class
TestGoogleDisplayVideo360CreateSDFDownloadTaskOperator:
task_id="test_task",
)
- op.execute(context=None)
+ op.execute(context=mock_context)
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
api_version=API_VERSION,
@@ -422,29 +422,29 @@ class
TestGoogleDisplayVideo360CreateSDFDownloadTaskOperator:
mock_hook.return_value.create_sdf_download_operation.assert_called_once_with(
body_request=body_request
)
- xcom_mock.assert_called_once_with(None, key="name", value=test_name)
+ mock_context["task_instance"].xcom_push.assert_called_once_with(key="name", value=test_name)
class TestGoogleDisplayVideo360CreateQueryOperator:
- @mock.patch(
- "airflow.providers.google.marketing_platform.operators."
- "display_video.GoogleDisplayVideo360CreateQueryOperator.xcom_push"
- )
@mock.patch(
"airflow.providers.google.marketing_platform.operators.display_video.GoogleDisplayVideo360Hook"
)
- def test_execute(self, hook_mock, xcom_mock):
+ def test_execute(self, hook_mock):
body = {"body": "test"}
+
+ # Create mock context with task_instance
+ mock_context = {"task_instance": mock.Mock()}
+
hook_mock.return_value.create_query.return_value = {"queryId":
QUERY_ID}
op = GoogleDisplayVideo360CreateQueryOperator(body=body,
task_id="test_task")
- op.execute(context=None)
+ op.execute(context=mock_context)
hook_mock.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
api_version="v2",
impersonation_chain=None,
)
hook_mock.return_value.create_query.assert_called_once_with(query=body)
- xcom_mock.assert_called_once_with(None, key="query_id", value=QUERY_ID)
+ mock_context["task_instance"].xcom_push.assert_called_once_with(key="query_id", value=QUERY_ID)
def test_prepare_template(self):
body = {"key": "value"}
diff --git a/providers/google/tests/unit/google/suite/operators/test_sheets.py
b/providers/google/tests/unit/google/suite/operators/test_sheets.py
index 6e04a2b9fe5..1d0fc216c03 100644
--- a/providers/google/tests/unit/google/suite/operators/test_sheets.py
+++ b/providers/google/tests/unit/google/suite/operators/test_sheets.py
@@ -27,11 +27,9 @@ SPREADSHEET_ID = "1234567890"
class TestGoogleSheetsCreateSpreadsheet:
@mock.patch("airflow.providers.google.suite.operators.sheets.GSheetsHook")
- @mock.patch(
-
"airflow.providers.google.suite.operators.sheets.GoogleSheetsCreateSpreadsheetOperator.xcom_push"
- )
- def test_execute(self, mock_xcom, mock_hook):
- context = {}
+ def test_execute(self, mock_hook):
+ mock_task_instance = mock.MagicMock()
+ context = {"task_instance": mock_task_instance}
spreadsheet = mock.MagicMock()
mock_hook.return_value.create_spreadsheet.return_value = {
"spreadsheetId": SPREADSHEET_ID,
@@ -44,5 +42,10 @@ class TestGoogleSheetsCreateSpreadsheet:
mock_hook.return_value.create_spreadsheet.assert_called_once_with(spreadsheet=spreadsheet)
+ # Verify xcom_push was called with correct arguments
+ assert mock_task_instance.xcom_push.call_count == 2
+ mock_task_instance.xcom_push.assert_any_call(key="spreadsheet_id", value=SPREADSHEET_ID)
+ mock_task_instance.xcom_push.assert_any_call(key="spreadsheet_url", value=SPREADSHEET_URL)
+
assert op_execute_result["spreadsheetId"] == "1234567890"
assert op_execute_result["spreadsheetUrl"] == "https://example/sheets"