This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d43c804f2b Add optional 'location' parameter to the
BigQueryInsertJobTrigger (#37282)
d43c804f2b is described below
commit d43c804f2bda3bc518682c9b2af94ea30475c879
Author: max <[email protected]>
AuthorDate: Mon Feb 12 11:33:36 2024 +0100
Add optional 'location' parameter to the BigQueryInsertJobTrigger (#37282)
---
airflow/providers/google/cloud/hooks/bigquery.py | 63 ++++++++++++++++++----
.../providers/google/cloud/operators/bigquery.py | 5 ++
.../google/cloud/transfers/bigquery_to_gcs.py | 1 +
.../google/cloud/transfers/gcs_to_bigquery.py | 1 +
.../providers/google/cloud/triggers/bigquery.py | 18 ++++++-
.../providers/google/cloud/hooks/test_bigquery.py | 19 +++----
.../google/cloud/triggers/test_bigquery.py | 35 ++++++------
7 files changed, 103 insertions(+), 39 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py
b/airflow/providers/google/cloud/hooks/bigquery.py
index 729d629c31..2b252f3c3f 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -20,6 +20,7 @@
from __future__ import annotations
+import asyncio
import json
import logging
import re
@@ -3242,16 +3243,58 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
session=cast(Session, session),
)
- async def get_job_status(self, job_id: str | None, project_id: str | None
= None) -> dict[str, str]:
- async with ClientSession() as s:
- job_client = await self.get_job_instance(project_id, job_id, s)
- job = await job_client.get_job()
- status = job.get("status", {})
- if status["state"] == "DONE":
- if "errorResult" in status:
- return {"status": "error", "message":
status["errorResult"]["message"]}
- return {"status": "success", "message": "Job completed"}
- return {"status": status["state"].lower(), "message": "Job
running"}
+ async def _get_job(
+ self, job_id: str | None, project_id: str | None = None, location: str
| None = None
+ ) -> CopyJob | QueryJob | LoadJob | ExtractJob | UnknownJob:
+ """
+ Get BigQuery job by its ID, project ID and location.
+
+ WARNING.
+ This is a temporary workaround for issues below, and it's not
intended to be used elsewhere!
+ https://github.com/apache/airflow/issues/35833
+ https://github.com/talkiq/gcloud-aio/issues/584
+
+    This method was developed because neither the `google-cloud-bigquery`
nor the `gcloud-aio-bigquery`
+    library provides asynchronous access to BigQuery jobs with a location
parameter. That's why this method wraps
+    the synchronous client call with the event loop's run_in_executor() method.
+
+ This workaround must be deleted along with the method _get_job_sync()
and replaced by more robust and
+ cleaner solution in one of two cases:
+ 1. The `google-cloud-bigquery` library provides async client with
get_job method, that supports
+ optional parameter `location`
+ 2. The `gcloud-aio-bigquery` library supports the `location`
parameter in get_job() method.
+ """
+ loop = asyncio.get_event_loop()
+ job = await loop.run_in_executor(None, self._get_job_sync, job_id,
project_id, location)
+ return job
+
+ def _get_job_sync(self, job_id, project_id, location):
+ """
+ Get BigQuery job by its ID, project ID and location synchronously.
+
+ WARNING
+ This is a temporary workaround for issues below, and it's not
intended to be used elsewhere!
+ https://github.com/apache/airflow/issues/35833
+ https://github.com/talkiq/gcloud-aio/issues/584
+
+ This workaround must be deleted along with the method _get_job() and
replaced by more robust and
+ cleaner solution in one of two cases:
+ 1. The `google-cloud-bigquery` library provides async client with
get_job method, that supports
+ optional parameter `location`
+ 2. The `gcloud-aio-bigquery` library supports the `location`
parameter in get_job() method.
+ """
+ hook = BigQueryHook(**self._hook_kwargs)
+ return hook.get_job(job_id=job_id, project_id=project_id,
location=location)
+
+ async def get_job_status(
+ self, job_id: str | None, project_id: str | None = None, location: str
| None = None
+ ) -> dict[str, str]:
+ job = await self._get_job(job_id=job_id, project_id=project_id,
location=location)
+ if job.state == "DONE":
+ if job.error_result:
+ return {"status": "error", "message":
job.error_result["message"]}
+ return {"status": "success", "message": "Job completed"}
+ return {"status": str(job.state).lower(), "message": "Job running"}
async def get_job_output(
self,
diff --git a/airflow/providers/google/cloud/operators/bigquery.py
b/airflow/providers/google/cloud/operators/bigquery.py
index fc90bbeed9..b391a1508c 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -313,6 +313,7 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin,
SQLCheckOperator):
conn_id=self.gcp_conn_id,
job_id=job.job_id,
project_id=hook.project_id,
+ location=self.location or hook.location,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
),
@@ -438,6 +439,7 @@ class BigQueryValueCheckOperator(_BigQueryDbHookMixin,
SQLValueCheckOperator):
conn_id=self.gcp_conn_id,
job_id=job.job_id,
project_id=hook.project_id,
+ location=self.location or hook.location,
sql=self.sql,
pass_value=self.pass_value,
tolerance=self.tol,
@@ -594,6 +596,7 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin,
SQLIntervalCheckOperat
second_job_id=job_2.job_id,
project_id=hook.project_id,
table=self.table,
+ location=self.location or hook.location,
metrics_thresholds=self.metrics_thresholds,
date_filter_column=self.date_filter_column,
days_back=self.days_back,
@@ -1068,6 +1071,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
dataset_id=self.dataset_id,
table_id=self.table_id,
project_id=self.job_project_id or hook.project_id,
+ location=self.location or hook.location,
poll_interval=self.poll_interval,
as_dict=self.as_dict,
impersonation_chain=self.impersonation_chain,
@@ -2876,6 +2880,7 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator,
_BigQueryOpenLineageMix
conn_id=self.gcp_conn_id,
job_id=self.job_id,
project_id=self.project_id,
+ location=self.location or hook.location,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
),
diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
index 3ede4db32f..aeb7f46f6e 100644
--- a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
@@ -261,6 +261,7 @@ class BigQueryToGCSOperator(BaseOperator):
conn_id=self.gcp_conn_id,
job_id=self._job_id,
project_id=self.project_id or self.hook.project_id,
+ location=self.location or self.hook.location,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
index 9d8ce53f4c..4e1d6e0919 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -435,6 +435,7 @@ class GCSToBigQueryOperator(BaseOperator):
conn_id=self.gcp_conn_id,
job_id=self.job_id,
project_id=self.project_id or self.hook.project_id,
+ location=self.location or self.hook.location,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py
b/airflow/providers/google/cloud/triggers/bigquery.py
index 3d73e91e98..bc9e812d1b 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -33,6 +33,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
:param conn_id: Reference to google cloud connection id
:param job_id: The ID of the job. It will be suffixed with hash of job
configuration
:param project_id: Google Cloud Project where the job is running
+ :param location: The dataset location.
:param dataset_id: The dataset ID of the requested table. (templated)
:param table_id: The table ID of the requested table. (templated)
:param poll_interval: polling period in seconds to check for the status.
(templated)
@@ -51,6 +52,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
conn_id: str,
job_id: str | None,
project_id: str | None,
+ location: str | None,
dataset_id: str | None = None,
table_id: str | None = None,
poll_interval: float = 4.0,
@@ -63,6 +65,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
self._job_conn = None
self.dataset_id = dataset_id
self.project_id = project_id
+ self.location = location
self.table_id = table_id
self.poll_interval = poll_interval
self.impersonation_chain = impersonation_chain
@@ -76,6 +79,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
"job_id": self.job_id,
"dataset_id": self.dataset_id,
"project_id": self.project_id,
+ "location": self.location,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
@@ -87,7 +91,9 @@ class BigQueryInsertJobTrigger(BaseTrigger):
hook = self._get_async_hook()
try:
while True:
- job_status = await hook.get_job_status(job_id=self.job_id,
project_id=self.project_id)
+ job_status = await hook.get_job_status(
+ job_id=self.job_id, project_id=self.project_id,
location=self.location
+ )
if job_status["status"] == "success":
yield TriggerEvent(
{
@@ -127,6 +133,7 @@ class BigQueryCheckTrigger(BigQueryInsertJobTrigger):
"job_id": self.job_id,
"dataset_id": self.dataset_id,
"project_id": self.project_id,
+ "location": self.location,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
@@ -201,6 +208,7 @@ class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
"job_id": self.job_id,
"dataset_id": self.dataset_id,
"project_id": self.project_id,
+ "location": self.location,
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
@@ -253,6 +261,7 @@ class
BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
:param dataset_id: The dataset ID of the requested table. (templated)
:param table: table name
:param metrics_thresholds: dictionary of ratios indexed by metrics
+ :param location: The dataset location.
:param date_filter_column: column name. (templated)
:param days_back: number of days between ds and the ds we want to check
against. (templated)
:param ratio_formula: ratio formula. (templated)
@@ -277,6 +286,7 @@ class
BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
project_id: str | None,
table: str,
metrics_thresholds: dict[str, int],
+ location: str | None = None,
date_filter_column: str | None = "ds",
days_back: SupportsAbs[int] = -7,
ratio_formula: str = "max_over_min",
@@ -290,6 +300,7 @@ class
BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
conn_id=conn_id,
job_id=first_job_id,
project_id=project_id,
+ location=location,
dataset_id=dataset_id,
table_id=table_id,
poll_interval=poll_interval,
@@ -317,6 +328,7 @@ class
BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
"project_id": self.project_id,
"table": self.table,
"metrics_thresholds": self.metrics_thresholds,
+ "location": self.location,
"date_filter_column": self.date_filter_column,
"days_back": self.days_back,
"ratio_formula": self.ratio_formula,
@@ -414,6 +426,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
:param tolerance: certain metrics for tolerance. (templated)
:param dataset_id: The dataset ID of the requested table. (templated)
:param table_id: The table ID of the requested table. (templated)
+ :param location: The dataset location
:param poll_interval: polling period in seconds to check for the status.
(templated)
:param impersonation_chain: Optional service account to impersonate using
short-term
credentials, or chained list of accounts required to get the
access_token
@@ -435,6 +448,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
tolerance: Any = None,
dataset_id: str | None = None,
table_id: str | None = None,
+ location: str | None = None,
poll_interval: float = 4.0,
impersonation_chain: str | Sequence[str] | None = None,
):
@@ -444,6 +458,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
project_id=project_id,
dataset_id=dataset_id,
table_id=table_id,
+ location=location,
poll_interval=poll_interval,
impersonation_chain=impersonation_chain,
)
@@ -464,6 +479,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
"sql": self.sql,
"table_id": self.table_id,
"tolerance": self.tolerance,
+ "location": self.location,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
},
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py
b/tests/providers/google/cloud/hooks/test_bigquery.py
index 27db487b61..5ca34b276f 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -2155,23 +2155,18 @@ class
TestBigQueryAsyncHookMethods(_BigQueryBaseAsyncTestClass):
assert isinstance(result, Job)
@pytest.mark.parametrize(
- "job_status, expected",
+ "job_state, error_result, expected",
[
- ({"status": {"state": "DONE"}}, {"status": "success", "message":
"Job completed"}),
- (
- {"status": {"state": "DONE", "errorResult": {"message":
"Timeout"}}},
- {"status": "error", "message": "Timeout"},
- ),
- ({"status": {"state": "running"}}, {"status": "running",
"message": "Job running"}),
+ ("DONE", None, {"status": "success", "message": "Job completed"}),
+ ("DONE", {"message": "Timeout"}, {"status": "error", "message":
"Timeout"}),
+ ("RUNNING", None, {"status": "running", "message": "Job running"}),
],
)
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
- async def test_get_job_status(self, mock_job_instance, job_status,
expected):
+
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
+ async def test_get_job_status(self, mock_get_job, job_state, error_result,
expected):
hook = BigQueryAsyncHook()
- mock_job_client = AsyncMock(Job)
- mock_job_instance.return_value = mock_job_client
- mock_job_instance.return_value.get_job.return_value = job_status
+ mock_get_job.return_value = mock.MagicMock(state=job_state,
error_result=error_result)
resp = await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID)
assert resp == expected
diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py
b/tests/providers/google/cloud/triggers/test_bigquery.py
index ea2e478d04..ed4861ca76 100644
--- a/tests/providers/google/cloud/triggers/test_bigquery.py
+++ b/tests/providers/google/cloud/triggers/test_bigquery.py
@@ -24,7 +24,7 @@ from unittest.mock import AsyncMock
import pytest
from aiohttp import ClientResponseError, RequestInfo
-from gcloud.aio.bigquery import Job, Table
+from gcloud.aio.bigquery import Table
from multidict import CIMultiDict
from yarl import URL
@@ -48,6 +48,7 @@ TEST_JOB_ID = "1234"
TEST_GCP_PROJECT_ID = "test-project"
TEST_DATASET_ID = "bq_dataset"
TEST_TABLE_ID = "bq_table"
+TEST_LOCATION = "US"
POLLING_PERIOD_SECONDS = 4.0
TEST_SQL_QUERY = "SELECT count(*) from Any"
TEST_PASS_VALUE = 2
@@ -73,6 +74,7 @@ def insert_job_trigger():
project_id=TEST_GCP_PROJECT_ID,
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
+ location=TEST_LOCATION,
poll_interval=POLLING_PERIOD_SECONDS,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
@@ -86,6 +88,7 @@ def get_data_trigger():
project_id=TEST_GCP_PROJECT_ID,
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
+ location=None,
poll_interval=POLLING_PERIOD_SECONDS,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
@@ -132,6 +135,7 @@ def check_trigger():
project_id=TEST_GCP_PROJECT_ID,
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
+ location=None,
poll_interval=POLLING_PERIOD_SECONDS,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
)
@@ -166,6 +170,7 @@ class TestBigQueryInsertJobTrigger:
"project_id": TEST_GCP_PROJECT_ID,
"dataset_id": TEST_DATASET_ID,
"table_id": TEST_TABLE_ID,
+ "location": TEST_LOCATION,
"poll_interval": POLLING_PERIOD_SECONDS,
"impersonation_chain": TEST_IMPERSONATION_CHAIN,
}
@@ -185,13 +190,11 @@ class TestBigQueryInsertJobTrigger:
)
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
- async def test_bigquery_insert_job_trigger_running(self,
mock_job_instance, caplog, insert_job_trigger):
+
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
+ async def test_bigquery_insert_job_trigger_running(self, mock_get_job,
caplog, insert_job_trigger):
"""Test that BigQuery Triggers do not fire while a query is still
running."""
- mock_job_client = AsyncMock(Job)
- mock_job_instance.return_value = mock_job_client
- mock_job_instance.return_value.get_job.return_value = {"status":
{"state": "running"}}
+ mock_get_job.return_value = mock.MagicMock(state="RUNNING")
caplog.set_level(logging.INFO)
task = asyncio.create_task(insert_job_trigger.run().__anext__())
@@ -245,17 +248,16 @@ class TestBigQueryGetDataTrigger:
"dataset_id": TEST_DATASET_ID,
"project_id": TEST_GCP_PROJECT_ID,
"table_id": TEST_TABLE_ID,
+ "location": None,
"poll_interval": POLLING_PERIOD_SECONDS,
}
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
- async def test_bigquery_get_data_trigger_running(self, mock_job_instance,
caplog, get_data_trigger):
+
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
+ async def test_bigquery_get_data_trigger_running(self, mock_get_job,
caplog, get_data_trigger):
"""Test that BigQuery Triggers do not fire while a query is still
running."""
- mock_job_client = AsyncMock(Job)
- mock_job_instance.return_value = mock_job_client
- mock_job_instance.return_value.get_job.return_value = {"status":
{"state": "running"}}
+ mock_get_job.return_value = mock.MagicMock(state="running")
caplog.set_level(logging.INFO)
task = asyncio.create_task(get_data_trigger.run().__anext__())
@@ -348,13 +350,11 @@ class TestBigQueryGetDataTrigger:
class TestBigQueryCheckTrigger:
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
- async def test_bigquery_check_trigger_running(self, mock_job_instance,
caplog, check_trigger):
+
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
+ async def test_bigquery_check_trigger_running(self, mock_get_job, caplog,
check_trigger):
"""Test that BigQuery Triggers do not fire while a query is still
running."""
- mock_job_client = AsyncMock(Job)
- mock_job_instance.return_value = mock_job_client
- mock_job_instance.return_value.get_job.return_value = {"status":
{"state": "running"}}
+ mock_get_job.return_value = mock.MagicMock(state="running")
task = asyncio.create_task(check_trigger.run().__anext__())
await asyncio.sleep(0.5)
@@ -406,6 +406,7 @@ class TestBigQueryCheckTrigger:
"dataset_id": TEST_DATASET_ID,
"project_id": TEST_GCP_PROJECT_ID,
"table_id": TEST_TABLE_ID,
+ "location": None,
"poll_interval": POLLING_PERIOD_SECONDS,
}
@@ -487,6 +488,7 @@ class TestBigQueryIntervalCheckTrigger:
"second_job_id": TEST_SECOND_JOB_ID,
"project_id": TEST_GCP_PROJECT_ID,
"table": TEST_TABLE_ID,
+ "location": None,
"metrics_thresholds": TEST_METRIC_THRESHOLDS,
"date_filter_column": TEST_DATE_FILTER_COLUMN,
"days_back": TEST_DAYS_BACK,
@@ -578,6 +580,7 @@ class TestBigQueryValueCheckTrigger:
"job_id": TEST_JOB_ID,
"dataset_id": TEST_DATASET_ID,
"project_id": TEST_GCP_PROJECT_ID,
+ "location": None,
"sql": TEST_SQL_QUERY,
"table_id": TEST_TABLE_ID,
"tolerance": TEST_TOLERANCE,