This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 054904bb9a Add ability to pass impersonation_chain to BigQuery
triggers (#35629)
054904bb9a is described below
commit 054904bb9a68eb50070a14fe7300cb1e78e2c579
Author: VladaZakharova <[email protected]>
AuthorDate: Wed Nov 15 19:38:56 2023 +0100
Add ability to pass impersonation_chain to BigQuery triggers (#35629)
This PR adds possibility to pass impersonation_chain to BigQuery triggers so
that customers can execute triggers in a different project by passing dedicated
SA.
Co-authored-by: Ulada Zakharava <[email protected]>
---
airflow/providers/google/cloud/hooks/bigquery.py | 24 +++++++
.../providers/google/cloud/operators/bigquery.py | 8 ++-
.../providers/google/cloud/triggers/bigquery.py | 84 ++++++++++++++++++----
dev/breeze/README.md | 2 +-
.../google/cloud/triggers/test_bigquery.py | 15 ++++
5 files changed, 116 insertions(+), 17 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py
b/airflow/providers/google/cloud/hooks/bigquery.py
index e140f4b5d3..802a134765 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -3086,6 +3086,18 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
sync_hook_class = BigQueryHook
+ def __init__(
+ self,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ):
+ super().__init__(
+ gcp_conn_id=gcp_conn_id,
+ impersonation_chain=impersonation_chain,
+ **kwargs,
+ )
+
async def get_job_instance(
self, project_id: str | None, job_id: str | None, session:
ClientSession
) -> Job:
@@ -3311,6 +3323,18 @@ class BigQueryTableAsyncHook(GoogleBaseAsyncHook):
sync_hook_class = BigQueryHook
+ def __init__(
+ self,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ):
+ super().__init__(
+ gcp_conn_id=gcp_conn_id,
+ impersonation_chain=impersonation_chain,
+ **kwargs,
+ )
+
async def get_table_client(
self, dataset: str, table_id: str, project_id: str, session:
ClientSession
) -> Table_async:
diff --git a/airflow/providers/google/cloud/operators/bigquery.py
b/airflow/providers/google/cloud/operators/bigquery.py
index 33285331d6..91ff370366 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -301,6 +301,7 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin,
SQLCheckOperator):
else:
hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
)
job = self._submit_job(hook, job_id="")
context["ti"].xcom_push(key="job_id", value=job.job_id)
@@ -312,6 +313,7 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin,
SQLCheckOperator):
job_id=job.job_id,
project_id=hook.project_id,
poll_interval=self.poll_interval,
+ impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)
@@ -424,7 +426,7 @@ class BigQueryValueCheckOperator(_BigQueryDbHookMixin,
SQLValueCheckOperator):
if not self.deferrable:
super().execute(context=context)
else:
- hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id)
+ hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
job = self._submit_job(hook, job_id="")
context["ti"].xcom_push(key="job_id", value=job.job_id)
@@ -439,6 +441,7 @@ class BigQueryValueCheckOperator(_BigQueryDbHookMixin,
SQLValueCheckOperator):
pass_value=self.pass_value,
tolerance=self.tol,
poll_interval=self.poll_interval,
+ impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)
@@ -573,7 +576,7 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin,
SQLIntervalCheckOperat
if not self.deferrable:
super().execute(context)
else:
- hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id)
+ hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
self.log.info("Using ratio formula: %s", self.ratio_formula)
self.log.info("Executing SQL check: %s", self.sql1)
@@ -596,6 +599,7 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin,
SQLIntervalCheckOperat
ratio_formula=self.ratio_formula,
ignore_zero=self.ignore_zero,
poll_interval=self.poll_interval,
+ impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py
b/airflow/providers/google/cloud/triggers/bigquery.py
index 2ccb20c59e..a28eada370 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import asyncio
-from typing import Any, AsyncIterator, SupportsAbs
+from typing import Any, AsyncIterator, Sequence, SupportsAbs
from aiohttp import ClientSession
from aiohttp.client_exceptions import ClientResponseError
@@ -35,7 +35,15 @@ class BigQueryInsertJobTrigger(BaseTrigger):
:param project_id: Google Cloud Project where the job is running
:param dataset_id: The dataset ID of the requested table. (templated)
:param table_id: The table ID of the requested table. (templated)
- :param poll_interval: polling period in seconds to check for the status
+ :param poll_interval: polling period in seconds to check for the status.
(templated)
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account.
(templated)
"""
def __init__(
@@ -46,6 +54,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
dataset_id: str | None = None,
table_id: str | None = None,
poll_interval: float = 4.0,
+ impersonation_chain: str | Sequence[str] | None = None,
):
super().__init__()
self.log.info("Using the connection %s .", conn_id)
@@ -56,6 +65,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
self.project_id = project_id
self.table_id = table_id
self.poll_interval = poll_interval
+ self.impersonation_chain = impersonation_chain
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BigQueryInsertJobTrigger arguments and classpath."""
@@ -68,6 +78,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
"project_id": self.project_id,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
+ "impersonation_chain": self.impersonation_chain,
},
)
@@ -101,7 +112,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
yield TriggerEvent({"status": "error", "message": str(e)})
def _get_async_hook(self) -> BigQueryAsyncHook:
- return BigQueryAsyncHook(gcp_conn_id=self.conn_id)
+ return BigQueryAsyncHook(gcp_conn_id=self.conn_id,
impersonation_chain=self.impersonation_chain)
class BigQueryCheckTrigger(BigQueryInsertJobTrigger):
@@ -118,6 +129,7 @@ class BigQueryCheckTrigger(BigQueryInsertJobTrigger):
"project_id": self.project_id,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
+ "impersonation_chain": self.impersonation_chain,
},
)
@@ -191,6 +203,7 @@ class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
"project_id": self.project_id,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
+ "impersonation_chain": self.impersonation_chain,
"as_dict": self.as_dict,
},
)
@@ -240,13 +253,20 @@ class
BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
:param dataset_id: The dataset ID of the requested table. (templated)
:param table: table name
:param metrics_thresholds: dictionary of ratios indexed by metrics
- :param date_filter_column: column name
- :param days_back: number of days between ds and the ds we want to check
- against
- :param ratio_formula: ration formula
- :param ignore_zero: boolean value to consider zero or not
+ :param date_filter_column: column name. (templated)
+ :param days_back: number of days between ds and the ds we want to check
against. (templated)
+    :param ratio_formula: ratio formula. (templated)
+ :param ignore_zero: boolean value to consider zero or not. (templated)
:param table_id: The table ID of the requested table. (templated)
- :param poll_interval: polling period in seconds to check for the status
+ :param poll_interval: polling period in seconds to check for the status.
(templated)
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account.
(templated)
"""
def __init__(
@@ -264,6 +284,7 @@ class
BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
dataset_id: str | None = None,
table_id: str | None = None,
poll_interval: float = 4.0,
+ impersonation_chain: str | Sequence[str] | None = None,
):
super().__init__(
conn_id=conn_id,
@@ -272,6 +293,7 @@ class
BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
dataset_id=dataset_id,
table_id=table_id,
poll_interval=poll_interval,
+ impersonation_chain=impersonation_chain,
)
self.conn_id = conn_id
self.first_job_id = first_job_id
@@ -299,6 +321,7 @@ class
BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
"days_back": self.days_back,
"ratio_formula": self.ratio_formula,
"ignore_zero": self.ignore_zero,
+ "impersonation_chain": self.impersonation_chain,
},
)
@@ -386,12 +409,20 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
:param conn_id: Reference to google cloud connection id
:param sql: the sql to be executed
:param pass_value: pass value
- :param job_id: The ID of the job
+ :param job_id: The ID of the job
:param project_id: Google Cloud Project where the job is running
- :param tolerance: certain metrics for tolerance
+ :param tolerance: certain metrics for tolerance. (templated)
:param dataset_id: The dataset ID of the requested table. (templated)
:param table_id: The table ID of the requested table. (templated)
- :param poll_interval: polling period in seconds to check for the status
+ :param poll_interval: polling period in seconds to check for the status.
(templated)
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+        account from the list granting this role to the originating account.
(templated)
"""
def __init__(
@@ -405,6 +436,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
dataset_id: str | None = None,
table_id: str | None = None,
poll_interval: float = 4.0,
+ impersonation_chain: str | Sequence[str] | None = None,
):
super().__init__(
conn_id=conn_id,
@@ -413,6 +445,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
dataset_id=dataset_id,
table_id=table_id,
poll_interval=poll_interval,
+ impersonation_chain=impersonation_chain,
)
self.sql = sql
self.pass_value = pass_value
@@ -432,6 +465,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
"table_id": self.table_id,
"tolerance": self.tolerance,
"poll_interval": self.poll_interval,
+ "impersonation_chain": self.impersonation_chain,
},
)
@@ -473,6 +507,14 @@ class BigQueryTableExistenceTrigger(BaseTrigger):
:param gcp_conn_id: Reference to google cloud connection id
:param hook_params: params for hook
:param poll_interval: polling period in seconds to check for the status
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account.
(templated)
"""
def __init__(
@@ -483,6 +525,7 @@ class BigQueryTableExistenceTrigger(BaseTrigger):
gcp_conn_id: str,
hook_params: dict[str, Any],
poll_interval: float = 4.0,
+ impersonation_chain: str | Sequence[str] | None = None,
):
self.dataset_id = dataset_id
self.project_id = project_id
@@ -490,6 +533,7 @@ class BigQueryTableExistenceTrigger(BaseTrigger):
self.gcp_conn_id: str = gcp_conn_id
self.poll_interval = poll_interval
self.hook_params = hook_params
+ self.impersonation_chain = impersonation_chain
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BigQueryTableExistenceTrigger arguments and classpath."""
@@ -502,11 +546,14 @@ class BigQueryTableExistenceTrigger(BaseTrigger):
"gcp_conn_id": self.gcp_conn_id,
"poll_interval": self.poll_interval,
"hook_params": self.hook_params,
+ "impersonation_chain": self.impersonation_chain,
},
)
def _get_async_hook(self) -> BigQueryTableAsyncHook:
- return BigQueryTableAsyncHook(gcp_conn_id=self.gcp_conn_id)
+ return BigQueryTableAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain
+ )
async def run(self) -> AsyncIterator[TriggerEvent]: # type:
ignore[override]
"""Will run until the table exists in the Google Big Query."""
@@ -561,6 +608,14 @@ class
BigQueryTablePartitionExistenceTrigger(BigQueryTableExistenceTrigger):
:param gcp_conn_id: Reference to google cloud connection id
:param hook_params: params for hook
:param poll_interval: polling period in seconds to check for the status
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account.
(templated)
"""
def __init__(self, partition_id: str, **kwargs):
@@ -578,13 +633,14 @@ class
BigQueryTablePartitionExistenceTrigger(BigQueryTableExistenceTrigger):
"table_id": self.table_id,
"gcp_conn_id": self.gcp_conn_id,
"poll_interval": self.poll_interval,
+ "impersonation_chain": self.impersonation_chain,
"hook_params": self.hook_params,
},
)
async def run(self) -> AsyncIterator[TriggerEvent]: # type:
ignore[override]
"""Will run until the table exists in the Google Big Query."""
- hook = BigQueryAsyncHook(gcp_conn_id=self.gcp_conn_id)
+ hook = BigQueryAsyncHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
job_id = None
while True:
if job_id is not None:
diff --git a/dev/breeze/README.md b/dev/breeze/README.md
index 06ddef4a32..bf0658208e 100644
--- a/dev/breeze/README.md
+++ b/dev/breeze/README.md
@@ -66,6 +66,6 @@ PLEASE DO NOT MODIFY THE HASH BELOW! IT IS AUTOMATICALLY
UPDATED BY PRE-COMMIT.
---------------------------------------------------------------------------------------------------------
-Package config hash:
51d9c2ec8af90c2941d58cf28397e9972d31718bc5d74538eb0614ed9418310e7b1d14bb3ee11f4df6e8403390869838217dc641cdb1416a223b7cf69adf1b20
+Package config hash: Missing file
/usr/local/google/home/uladaz/Documents/Project/Airflow/airflow/dev/breeze/setup.py
---------------------------------------------------------------------------------------------------------
diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py
b/tests/providers/google/cloud/triggers/test_bigquery.py
index 1611c8e3e0..ea2e478d04 100644
--- a/tests/providers/google/cloud/triggers/test_bigquery.py
+++ b/tests/providers/google/cloud/triggers/test_bigquery.py
@@ -60,6 +60,7 @@ TEST_DAYS_BACK = -7
TEST_RATIO_FORMULA = "max_over_min"
TEST_IGNORE_ZERO = True
TEST_GCP_CONN_ID = "TEST_GCP_CONN_ID"
+TEST_IMPERSONATION_CHAIN = "TEST_SERVICE_ACCOUNT"
TEST_HOOK_PARAMS: dict[str, Any] = {}
TEST_PARTITION_ID = "1234"
@@ -73,6 +74,7 @@ def insert_job_trigger():
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
poll_interval=POLLING_PERIOD_SECONDS,
+ impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
@@ -85,6 +87,7 @@ def get_data_trigger():
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
poll_interval=POLLING_PERIOD_SECONDS,
+ impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
@@ -97,6 +100,7 @@ def table_existence_trigger():
TEST_GCP_CONN_ID,
TEST_HOOK_PARAMS,
POLLING_PERIOD_SECONDS,
+ TEST_IMPERSONATION_CHAIN,
)
@@ -116,6 +120,7 @@ def interval_check_trigger():
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
poll_interval=POLLING_PERIOD_SECONDS,
+ impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
@@ -128,6 +133,7 @@ def check_trigger():
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
poll_interval=POLLING_PERIOD_SECONDS,
+ impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
@@ -143,6 +149,7 @@ def value_check_trigger():
table_id=TEST_TABLE_ID,
tolerance=TEST_TOLERANCE,
poll_interval=POLLING_PERIOD_SECONDS,
+ impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
@@ -160,6 +167,7 @@ class TestBigQueryInsertJobTrigger:
"dataset_id": TEST_DATASET_ID,
"table_id": TEST_TABLE_ID,
"poll_interval": POLLING_PERIOD_SECONDS,
+ "impersonation_chain": TEST_IMPERSONATION_CHAIN,
}
@pytest.mark.asyncio
@@ -232,6 +240,7 @@ class TestBigQueryGetDataTrigger:
assert kwargs == {
"as_dict": False,
"conn_id": TEST_CONN_ID,
+ "impersonation_chain": TEST_IMPERSONATION_CHAIN,
"job_id": TEST_JOB_ID,
"dataset_id": TEST_DATASET_ID,
"project_id": TEST_GCP_PROJECT_ID,
@@ -392,6 +401,7 @@ class TestBigQueryCheckTrigger:
assert classpath ==
"airflow.providers.google.cloud.triggers.bigquery.BigQueryCheckTrigger"
assert kwargs == {
"conn_id": TEST_CONN_ID,
+ "impersonation_chain": TEST_IMPERSONATION_CHAIN,
"job_id": TEST_JOB_ID,
"dataset_id": TEST_DATASET_ID,
"project_id": TEST_GCP_PROJECT_ID,
@@ -472,6 +482,7 @@ class TestBigQueryIntervalCheckTrigger:
assert classpath ==
"airflow.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger"
assert kwargs == {
"conn_id": TEST_CONN_ID,
+ "impersonation_chain": TEST_IMPERSONATION_CHAIN,
"first_job_id": TEST_FIRST_JOB_ID,
"second_job_id": TEST_SECOND_JOB_ID,
"project_id": TEST_GCP_PROJECT_ID,
@@ -562,6 +573,7 @@ class TestBigQueryValueCheckTrigger:
assert classpath ==
"airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger"
assert kwargs == {
"conn_id": TEST_CONN_ID,
+ "impersonation_chain": TEST_IMPERSONATION_CHAIN,
"pass_value": TEST_PASS_VALUE,
"job_id": TEST_JOB_ID,
"dataset_id": TEST_DATASET_ID,
@@ -659,6 +671,7 @@ class TestBigQueryTableExistenceTrigger:
"project_id": TEST_GCP_PROJECT_ID,
"table_id": TEST_TABLE_ID,
"gcp_conn_id": TEST_GCP_CONN_ID,
+ "impersonation_chain": TEST_IMPERSONATION_CHAIN,
"poll_interval": POLLING_PERIOD_SECONDS,
"hook_params": TEST_HOOK_PARAMS,
}
@@ -777,6 +790,7 @@ class TestBigQueryTablePartitionExistenceTrigger:
partition_id=TEST_PARTITION_ID,
poll_interval=POLLING_PERIOD_SECONDS,
gcp_conn_id=TEST_GCP_CONN_ID,
+ impersonation_chain=TEST_IMPERSONATION_CHAIN,
hook_params={},
)
@@ -791,6 +805,7 @@ class TestBigQueryTablePartitionExistenceTrigger:
"table_id": TEST_TABLE_ID,
"partition_id": TEST_PARTITION_ID,
"gcp_conn_id": TEST_GCP_CONN_ID,
+ "impersonation_chain": TEST_IMPERSONATION_CHAIN,
"poll_interval": POLLING_PERIOD_SECONDS,
"hook_params": TEST_HOOK_PARAMS,
}