This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ce50d3731a Use insert_job in the BigQueryToGCPOpertor and adjust links (#24416)
ce50d3731a is described below
commit ce50d3731a049047d31d09c6d38a470b84cf57e7
Author: Łukasz Wyszomirski <[email protected]>
AuthorDate: Wed Jun 15 11:45:08 2022 +0200
Use insert_job in the BigQueryToGCPOpertor and adjust links (#24416)
* Use insert_job in the BigQueryToGCPOpertor and adjust links
---
airflow/providers/google/cloud/hooks/bigquery.py | 91 ++++++++++++++-
airflow/providers/google/cloud/links/bigquery.py | 4 +-
.../providers/google/cloud/operators/bigquery.py | 64 ++++++-----
.../google/cloud/transfers/bigquery_to_gcs.py | 122 ++++++++++++++++++---
.../providers/google/cloud/hooks/test_bigquery.py | 34 +++++-
.../google/cloud/operators/test_bigquery.py | 67 +++--------
.../google/cloud/transfers/test_bigquery_to_gcs.py | 43 ++++++--
7 files changed, 310 insertions(+), 115 deletions(-)
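For context, a minimal sketch (not part of the commit) of how the hook-level helpers introduced here, BigQueryHook.generate_job_id and BigQueryHook.split_tablename, can be called directly; the connection id, DAG/task ids and table reference below are illustrative placeholders:

    from datetime import datetime

    from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

    # Assumes a configured "google_cloud_default" connection; neither helper
    # below actually contacts BigQuery.
    hook = BigQueryHook(gcp_conn_id="google_cloud_default")

    # Deterministic job id: md5 of the sorted JSON configuration (or a random
    # uuid-based suffix when force_rerun=True).
    configuration = {"query": {"query": "SELECT 1", "useLegacySql": False}}
    job_id = hook.generate_job_id(
        job_id=None,
        dag_id="example_dag",
        task_id="example_task",
        logical_date=datetime(2022, 6, 15),
        configuration=configuration,
        force_rerun=False,
    )

    # Table reference parsing now lives on the hook as well.
    project, dataset, table = hook.split_tablename(
        table_input="my-project:my_dataset.my_table",
        default_project_id="my-project",
    )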
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 9edde76f3c..1ae2a4d4ab 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -23,7 +23,9 @@ implementation for BigQuery.
import hashlib
import json
import logging
+import re
import time
+import uuid
import warnings
from copy import deepcopy
from datetime import datetime, timedelta
@@ -1698,7 +1700,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
f"Please only use one or more of the following options:
{allowed_schema_update_options}"
)
- destination_project, destination_dataset, destination_table =
_split_tablename(
+ destination_project, destination_dataset, destination_table =
self.split_tablename(
table_input=destination_project_dataset_table,
default_project_id=self.project_id,
var_name='destination_project_dataset_table',
@@ -1850,7 +1852,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
source_project_dataset_tables_fixup = []
for source_project_dataset_table in source_project_dataset_tables:
- source_project, source_dataset, source_table = _split_tablename(
+ source_project, source_dataset, source_table = self.split_tablename(
table_input=source_project_dataset_table,
default_project_id=self.project_id,
var_name='source_project_dataset_table',
@@ -1859,7 +1861,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
{'projectId': source_project, 'datasetId': source_dataset, 'tableId': source_table}
)
- destination_project, destination_dataset, destination_table = _split_tablename(
+ destination_project, destination_dataset, destination_table = self.split_tablename(
table_input=destination_project_dataset_table,
default_project_id=self.project_id
)
configuration = {
@@ -1924,7 +1926,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
if not self.project_id:
raise ValueError("The project_id should be set")
- source_project, source_dataset, source_table = _split_tablename(
+ source_project, source_dataset, source_table = self.split_tablename(
table_input=source_project_dataset_table,
default_project_id=self.project_id,
var_name='source_project_dataset_table',
@@ -2092,7 +2094,7 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
)
if destination_dataset_table:
- destination_project, destination_dataset, destination_table = _split_tablename(
+ destination_project, destination_dataset, destination_table = self.split_tablename(
table_input=destination_dataset_table,
default_project_id=self.project_id
)
@@ -2180,6 +2182,83 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
self.running_job_id = job.job_id
return job.job_id
+ def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False):
+ if force_rerun:
+ hash_base = str(uuid.uuid4())
+ else:
+ hash_base = json.dumps(configuration, sort_keys=True)
+
+ uniqueness_suffix = hashlib.md5(hash_base.encode()).hexdigest()
+
+ if job_id:
+ return f"{job_id}_{uniqueness_suffix}"
+
+ exec_date = logical_date.isoformat()
+ job_id = f"airflow_{dag_id}_{task_id}_{exec_date}_{uniqueness_suffix}"
+ return re.sub(r"[:\-+.]", "_", job_id)
+
+ def split_tablename(
+ self, table_input: str, default_project_id: str, var_name: Optional[str] = None
+ ) -> Tuple[str, str, str]:
+
+ if '.' not in table_input:
+ raise ValueError(f'Expected table name in the format of <dataset>.<table>. Got: {table_input}')
+
+ if not default_project_id:
+ raise ValueError("INTERNAL: No default project is specified")
+
+ def var_print(var_name):
+ if var_name is None:
+ return ""
+ else:
+ return f"Format exception for {var_name}: "
+
+ if table_input.count('.') + table_input.count(':') > 3:
+ raise Exception(f'{var_print(var_name)}Use either : or . to specify project got {table_input}')
+ cmpt = table_input.rsplit(':', 1)
+ project_id = None
+ rest = table_input
+ if len(cmpt) == 1:
+ project_id = None
+ rest = cmpt[0]
+ elif len(cmpt) == 2 and cmpt[0].count(':') <= 1:
+ if cmpt[-1].count('.') != 2:
+ project_id = cmpt[0]
+ rest = cmpt[1]
+ else:
+ raise Exception(
+ f'{var_print(var_name)}Expect format of (<project:)<dataset>.<table>, got {table_input}'
+ )
+
+ cmpt = rest.split('.')
+ if len(cmpt) == 3:
+ if project_id:
+ raise ValueError(f"{var_print(var_name)}Use either : or . to
specify project")
+ project_id = cmpt[0]
+ dataset_id = cmpt[1]
+ table_id = cmpt[2]
+
+ elif len(cmpt) == 2:
+ dataset_id = cmpt[0]
+ table_id = cmpt[1]
+ else:
+ raise Exception(
+ f'{var_print(var_name)} Expect format of (<project.|<project:)<dataset>.<table>, '
+ f'got {table_input}'
+ )
+
+ if project_id is None:
+ if var_name is not None:
+ self.log.info(
+ 'Project not included in %s: %s; using project "%s"',
+ var_name,
+ table_input,
+ default_project_id,
+ )
+ project_id = default_project_id
+
+ return project_id, dataset_id, table_id
+
class BigQueryConnection:
"""
@@ -2771,7 +2850,7 @@ def _bq_cast(string_field: str, bq_type: str) -> Union[None, int, float, bool, s
return string_field
-def _split_tablename(
+def split_tablename(
table_input: str, default_project_id: str, var_name: Optional[str] = None
) -> Tuple[str, str, str]:
diff --git a/airflow/providers/google/cloud/links/bigquery.py b/airflow/providers/google/cloud/links/bigquery.py
index a80818e203..3f814fcf65 100644
--- a/airflow/providers/google/cloud/links/bigquery.py
+++ b/airflow/providers/google/cloud/links/bigquery.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains Google BigQuery links."""
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
from airflow.models import BaseOperator
from airflow.providers.google.cloud.links.base import BaseGoogleLink
@@ -66,9 +66,9 @@ class BigQueryTableLink(BaseGoogleLink):
def persist(
context: "Context",
task_instance: BaseOperator,
- dataset_id: str,
project_id: str,
table_id: str,
+ dataset_id: Optional[str] = None,
):
task_instance.xcom_push(
context,
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 3533e4d7b6..6bbe50e5b3 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -19,10 +19,7 @@
"""This module contains Google BigQuery operators."""
import enum
-import hashlib
import json
-import re
-import uuid
import warnings
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Set, SupportsAbs, Union
@@ -30,7 +27,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence,
import attr
from google.api_core.exceptions import Conflict
from google.api_core.retry import Retry
-from google.cloud.bigquery import DEFAULT_RETRY
+from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, QueryJob
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink
@@ -2137,21 +2134,6 @@ class BigQueryInsertJobOperator(BaseOperator):
if job.error_result:
raise AirflowException(f"BigQuery job {job.job_id} failed:
{job.error_result}")
- def _job_id(self, context):
- if self.force_rerun:
- hash_base = str(uuid.uuid4())
- else:
- hash_base = json.dumps(self.configuration, sort_keys=True)
-
- uniqueness_suffix = hashlib.md5(hash_base.encode()).hexdigest()
-
- if self.job_id:
- return f"{self.job_id}_{uniqueness_suffix}"
-
- exec_date = context['logical_date'].isoformat()
- job_id = f"airflow_{self.dag_id}_{self.task_id}_{exec_date}_{uniqueness_suffix}"
- return re.sub(r"[:\-+.]", "_", job_id)
-
def execute(self, context: Any):
hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
@@ -2160,7 +2142,14 @@ class BigQueryInsertJobOperator(BaseOperator):
)
self.hook = hook
- job_id = self._job_id(context)
+ job_id = hook.generate_job_id(
+ job_id=self.job_id,
+ dag_id=self.dag_id,
+ task_id=self.task_id,
+ logical_date=context["logical_date"],
+ configuration=self.configuration,
+ force_rerun=self.force_rerun,
+ )
try:
self.log.info(f"Executing: {self.configuration}")
@@ -2185,16 +2174,31 @@ class BigQueryInsertJobOperator(BaseOperator):
f"Or, if you want to reattach in this scenario add
{job.state} to `reattach_states`"
)
- if "query" in job.to_api_repr()["configuration"]:
- if "destinationTable" in
job.to_api_repr()["configuration"]["query"]:
- table =
job.to_api_repr()["configuration"]["query"]["destinationTable"]
- BigQueryTableLink.persist(
- context=context,
- task_instance=self,
- dataset_id=table["datasetId"],
- project_id=table["projectId"],
- table_id=table["tableId"],
- )
+ job_types = {
+ LoadJob._JOB_TYPE: ["sourceTable", "destinationTable"],
+ CopyJob._JOB_TYPE: ["sourceTable", "destinationTable"],
+ ExtractJob._JOB_TYPE: ["sourceTable"],
+ QueryJob._JOB_TYPE: ["destinationTable"],
+ }
+
+ if self.project_id:
+ for job_type, tables_prop in job_types.items():
+ job_configuration = job.to_api_repr()["configuration"]
+ if job_type in job_configuration:
+ for table_prop in tables_prop:
+ if table_prop in job_configuration[job_type]:
+ table = job_configuration[job_type][table_prop]
+ persist_kwargs = {
+ "context": context,
+ "task_instance": self,
+ "project_id": self.project_id,
+ "table_id": table,
+ }
+ if not isinstance(table, str):
+ persist_kwargs["table_id"] = table["tableId"]
+ persist_kwargs["dataset_id"] =
table["datasetId"]
+
+ BigQueryTableLink.persist(**persist_kwargs)
self.job_id = job.job_id
return job.job_id
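The BigQueryTableLink persistence above is now driven by a small job-type table. A minimal standalone sketch (illustrative only; the literal job-type keys stand in for LoadJob._JOB_TYPE and friends, and api_repr is a hand-written stand-in for job.to_api_repr()):

    # Map BigQuery job types to the table properties that carry a table
    # reference in the job's API representation.
    job_types = {
        "load": ["sourceTable", "destinationTable"],
        "copy": ["sourceTable", "destinationTable"],
        "extract": ["sourceTable"],
        "query": ["destinationTable"],
    }

    # Hand-written stand-in for job.to_api_repr().
    api_repr = {
        "configuration": {
            "extract": {
                "sourceTable": {
                    "projectId": "my-project",
                    "datasetId": "my_dataset",
                    "tableId": "my_table",
                }
            }
        }
    }

    job_configuration = api_repr["configuration"]
    for job_type, table_props in job_types.items():
        if job_type in job_configuration:
            for table_prop in table_props:
                if table_prop in job_configuration[job_type]:
                    table = job_configuration[job_type][table_prop]
                    # A plain string reference only yields a table id; a dict
                    # reference also carries the dataset id for the link.
                    if isinstance(table, str):
                        print("table_id:", table)
                    else:
                        print("dataset_id:", table["datasetId"], "table_id:", table["tableId"])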
diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
index fbae5a5b77..509fe5def4 100644
--- a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
@@ -16,10 +16,15 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains Google BigQuery to Google Cloud Storage operator."""
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Set, Union
+from google.api_core.exceptions import Conflict
+from google.api_core.retry import Retry
+from google.cloud.bigquery import DEFAULT_RETRY, ExtractJob
+
+from airflow import AirflowException
from airflow.models import BaseOperator
-from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
+from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
if TYPE_CHECKING:
@@ -42,6 +47,7 @@ class BigQueryToGCSOperator(BaseOperator):
Storage URI (e.g. gs://some-bucket/some-file.txt). (templated) Follows
convention defined here:
https://cloud.google.com/bigquery/exporting-data-from-bigquery#exportingmultiple
+ :param project_id: Google Cloud Project where the job is running
:param compression: Type of compression to use.
:param export_format: File format to export.
:param field_delimiter: The delimiter to use when extracting to a CSV.
@@ -61,6 +67,16 @@ class BigQueryToGCSOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
+ :param result_retry: How to retry the `result` call that retrieves rows
+ :param result_timeout: The number of seconds to wait for `result` method before using `result_retry`
+ :param job_id: The ID of the job. It will be suffixed with hash of job configuration
+ unless ``force_rerun`` is True.
+ The ID must contain only letters (a-z, A-Z), numbers (0-9), underscores (_), or
+ dashes (-). The maximum length is 1,024 characters. If not provided then uuid will
+ be generated.
+ :param force_rerun: If True then operator will use hash of uuid as job id suffix
+ :param reattach_states: Set of BigQuery job's states in case of which we should reattach
+ to the job. Should be other than final states.
"""
template_fields: Sequence[str] = (
@@ -78,6 +94,7 @@ class BigQueryToGCSOperator(BaseOperator):
*,
source_project_dataset_table: str,
destination_cloud_storage_uris: List[str],
+ project_id: Optional[str] = None,
compression: str = 'NONE',
export_format: str = 'CSV',
field_delimiter: str = ',',
@@ -87,10 +104,15 @@ class BigQueryToGCSOperator(BaseOperator):
labels: Optional[Dict] = None,
location: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ result_retry: Retry = DEFAULT_RETRY,
+ result_timeout: Optional[float] = None,
+ job_id: Optional[str] = None,
+ force_rerun: bool = False,
+ reattach_states: Optional[Set[str]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
-
+ self.project_id = project_id
self.source_project_dataset_table = source_project_dataset_table
self.destination_cloud_storage_uris = destination_cloud_storage_uris
self.compression = compression
@@ -102,6 +124,48 @@ class BigQueryToGCSOperator(BaseOperator):
self.labels = labels
self.location = location
self.impersonation_chain = impersonation_chain
+ self.result_retry = result_retry
+ self.result_timeout = result_timeout
+ self.job_id = job_id
+ self.force_rerun = force_rerun
+ self.reattach_states: Set[str] = reattach_states or set()
+ self.hook: Optional[BigQueryHook] = None
+
+ @staticmethod
+ def _handle_job_error(job: ExtractJob) -> None:
+ if job.error_result:
+ raise AirflowException(f"BigQuery job {job.job_id} failed:
{job.error_result}")
+
+ def _prepare_configuration(self):
+ source_project, source_dataset, source_table = self.hook.split_tablename(
+ table_input=self.source_project_dataset_table,
+ default_project_id=self.project_id or self.hook.project_id,
+ var_name='source_project_dataset_table',
+ )
+
+ configuration: Dict[str, Any] = {
+ 'extract': {
+ 'sourceTable': {
+ 'projectId': source_project,
+ 'datasetId': source_dataset,
+ 'tableId': source_table,
+ },
+ 'compression': self.compression,
+ 'destinationUris': self.destination_cloud_storage_uris,
+ 'destinationFormat': self.export_format,
+ }
+ }
+
+ if self.labels:
+ configuration['labels'] = self.labels
+
+ if self.export_format == 'CSV':
+ # Only set fieldDelimiter and printHeader fields if using CSV.
+ # Google does not like it if you set these fields for other export
+ # formats.
+ configuration['extract']['fieldDelimiter'] = self.field_delimiter
+ configuration['extract']['printHeader'] = self.print_header
+ return configuration
def execute(self, context: 'Context'):
self.log.info(
@@ -115,17 +179,49 @@ class BigQueryToGCSOperator(BaseOperator):
location=self.location,
impersonation_chain=self.impersonation_chain,
)
- job: BigQueryJob = hook.run_extract(
- source_project_dataset_table=self.source_project_dataset_table,
- destination_cloud_storage_uris=self.destination_cloud_storage_uris,
- compression=self.compression,
- export_format=self.export_format,
- field_delimiter=self.field_delimiter,
- print_header=self.print_header,
- labels=self.labels,
- return_full_job=True,
+ self.hook = hook
+
+ configuration = self._prepare_configuration()
+ job_id = hook.generate_job_id(
+ job_id=self.job_id,
+ dag_id=self.dag_id,
+ task_id=self.task_id,
+ logical_date=context["logical_date"],
+ configuration=configuration,
+ force_rerun=self.force_rerun,
)
- conf = job["configuration"]["extract"]["sourceTable"]
+
+ try:
+ self.log.info("Executing: %s", configuration)
+ job: ExtractJob = hook.insert_job(
+ job_id=job_id,
+ configuration=configuration,
+ project_id=self.project_id,
+ location=self.location,
+ timeout=self.result_timeout,
+ retry=self.result_retry,
+ )
+ self._handle_job_error(job)
+ except Conflict:
+ # If the job already exists retrieve it
+ job = hook.get_job(
+ project_id=self.project_id,
+ location=self.location,
+ job_id=job_id,
+ )
+ if job.state in self.reattach_states:
+ # We are reattaching to a job
+ job.result(timeout=self.result_timeout, retry=self.result_retry)
+ self._handle_job_error(job)
+ else:
+ # Same job configuration so we need force_rerun
+ raise AirflowException(
+ f"Job with id: {job_id} already exists and is in
{job.state} state. If you "
+ f"want to force rerun it consider setting
`force_rerun=True`."
+ f"Or, if you want to reattach in this scenario add
{job.state} to `reattach_states`"
+ )
+
+ conf = job.to_api_repr()["configuration"]["extract"]["sourceTable"]
dataset_id, project_id, table_id = conf["datasetId"], conf["projectId"], conf["tableId"]
BigQueryTableLink.persist(
context=context,
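A minimal usage sketch (not part of the commit) of the reworked BigQueryToGCSOperator with the newly added parameters; the DAG id, table and bucket names are placeholders, and the reattach_states values are simply an example of non-final BigQuery job states:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator

    with DAG(
        dag_id="example_bigquery_to_gcs",
        start_date=datetime(2022, 6, 15),
        schedule_interval=None,
    ) as dag:
        export_table = BigQueryToGCSOperator(
            task_id="export_table",
            source_project_dataset_table="my-project.my_dataset.my_table",
            destination_cloud_storage_uris=["gs://my-bucket/my_table/*.csv"],
            export_format="CSV",
            # New in this change: the job id gets a hash suffix derived from the
            # configuration, force_rerun randomizes that suffix, and
            # reattach_states lets the task pick up an already-running job.
            job_id="export_my_table",
            force_rerun=False,
            reattach_states={"PENDING", "RUNNING"},
            result_timeout=300.0,
        )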
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 9de8333eb4..f143a14c91 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -19,6 +19,7 @@
import re
import unittest
+from datetime import datetime
from unittest import mock
import pytest
@@ -33,9 +34,9 @@ from airflow.providers.google.cloud.hooks.bigquery import (
BigQueryHook,
_api_resource_configs_duplication_check,
_cleanse_time_partitioning,
- _split_tablename,
_validate_src_fmt_configs,
_validate_value,
+ split_tablename,
)
PROJECT_ID = "bq-project"
@@ -918,11 +919,36 @@ class TestBigQueryHookMethods(_BigQueryBaseTestClass):
def test_dbapi_get_uri(self):
assert self.hook.get_uri().startswith('bigquery://')
+ @mock.patch('airflow.providers.google.cloud.hooks.bigquery.hashlib.md5')
+ @pytest.mark.parametrize(
+ "test_dag_id, expected_job_id",
+ [("test-dag-id-1.1",
"airflow_test_dag_id_1_1_test_job_id_2020_01_23T00_00_00_hash")],
+ ids=["test-dag-id-1.1"],
+ )
+ def test_job_id_validity(self, mock_md5, test_dag_id, expected_job_id):
+ hash_ = "hash"
+ mock_md5.return_value.hexdigest.return_value = hash_
+ configuration = {
+ "query": {
+ "query": "SELECT * FROM any",
+ "useLegacySql": False,
+ }
+ }
+
+ job_id = self.hook.generate_job_id(
+ job_id=None,
+ dag_id=test_dag_id,
+ task_id="test_job_id",
+ logical_date=datetime(2020, 1, 23),
+ configuration=configuration,
+ )
+ assert job_id == expected_job_id
+
class TestBigQueryTableSplitter(unittest.TestCase):
def test_internal_need_default_project(self):
with pytest.raises(Exception, match="INTERNAL: No default project is
specified"):
- _split_tablename("dataset.table", None)
+ split_tablename("dataset.table", None)
@parameterized.expand(
[
@@ -935,7 +961,7 @@ class TestBigQueryTableSplitter(unittest.TestCase):
)
def test_split_tablename(self, project_expected, dataset_expected, table_expected, table_input):
default_project_id = "project"
- project, dataset, table = _split_tablename(table_input, default_project_id)
+ project, dataset, table = split_tablename(table_input, default_project_id)
assert project_expected == project
assert dataset_expected == dataset
assert table_expected == table
@@ -969,7 +995,7 @@ class TestBigQueryTableSplitter(unittest.TestCase):
def test_invalid_syntax(self, table_input, var_name, exception_message):
default_project_id = "project"
with pytest.raises(Exception, match=exception_message.format(table_input)):
- _split_tablename(table_input, default_project_id, var_name)
+ split_tablename(table_input, default_project_id, var_name)
class TestTableOperations(_BigQueryBaseTestClass):
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index 82dbbe3aef..9a060d2f6e 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
import unittest
from unittest import mock
from unittest.mock import MagicMock
@@ -25,7 +24,6 @@ from google.cloud.bigquery import DEFAULT_RETRY
from google.cloud.exceptions import Conflict
from airflow.exceptions import AirflowException
-from airflow.models import DAG
from airflow.providers.google.cloud.operators.bigquery import (
BigQueryCheckOperator,
BigQueryConsoleIndexableLink,
@@ -786,13 +784,11 @@ class TestBigQueryUpsertTableOperator(unittest.TestCase):
class TestBigQueryInsertJobOperator:
- @mock.patch('airflow.providers.google.cloud.operators.bigquery.hashlib.md5')
@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
- def test_execute_query_success(self, mock_hook, mock_md5):
+ def test_execute_query_success(self, mock_hook):
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"
- mock_md5.return_value.hexdigest.return_value = hash_
configuration = {
"query": {
@@ -801,6 +797,7 @@ class TestBigQueryInsertJobOperator:
}
}
mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
+ mock_hook.return_value.generate_job_id.return_value = real_job_id
op = BigQueryInsertJobOperator(
task_id="insert_query_job",
@@ -822,13 +819,11 @@ class TestBigQueryInsertJobOperator:
assert result == real_job_id
- @mock.patch('airflow.providers.google.cloud.operators.bigquery.hashlib.md5')
@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
- def test_execute_copy_success(self, mock_hook, mock_md5):
+ def test_execute_copy_success(self, mock_hook):
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"
- mock_md5.return_value.hexdigest.return_value = hash_
configuration = {
"copy": {
@@ -841,7 +836,7 @@ class TestBigQueryInsertJobOperator:
"jobReference": "a",
}
mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
-
+ mock_hook.return_value.generate_job_id.return_value = real_job_id
mock_hook.return_value.insert_job.return_value.to_api_repr.return_value = mock_configuration
op = BigQueryInsertJobOperator(
@@ -864,13 +859,11 @@ class TestBigQueryInsertJobOperator:
assert result == real_job_id
- @mock.patch('airflow.providers.google.cloud.operators.bigquery.hashlib.md5')
@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
- def test_on_kill(self, mock_hook, mock_md5):
+ def test_on_kill(self, mock_hook):
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"
- mock_md5.return_value.hexdigest.return_value = hash_
configuration = {
"query": {
@@ -879,6 +872,7 @@ class TestBigQueryInsertJobOperator:
}
}
mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
+ mock_hook.return_value.generate_job_id.return_value = real_job_id
op = BigQueryInsertJobOperator(
task_id="insert_query_job",
@@ -901,13 +895,11 @@ class TestBigQueryInsertJobOperator:
project_id=TEST_GCP_PROJECT_ID,
)
- @mock.patch('airflow.providers.google.cloud.operators.bigquery.hashlib.md5')
@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
- def test_execute_failure(self, mock_hook, mock_md5):
+ def test_execute_failure(self, mock_hook):
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"
- mock_md5.return_value.hexdigest.return_value = hash_
configuration = {
"query": {
@@ -916,6 +908,7 @@ class TestBigQueryInsertJobOperator:
}
}
mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=True)
+ mock_hook.return_value.generate_job_id.return_value = real_job_id
op = BigQueryInsertJobOperator(
task_id="insert_query_job",
@@ -925,15 +918,13 @@ class TestBigQueryInsertJobOperator:
project_id=TEST_GCP_PROJECT_ID,
)
with pytest.raises(AirflowException):
- op.execute({})
+ op.execute(context=MagicMock())
- @mock.patch('airflow.providers.google.cloud.operators.bigquery.hashlib.md5')
@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
- def test_execute_reattach(self, mock_hook, mock_md5):
+ def test_execute_reattach(self, mock_hook):
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"
- mock_md5.return_value.hexdigest.return_value = hash_
configuration = {
"query": {
@@ -950,6 +941,7 @@ class TestBigQueryInsertJobOperator:
done=lambda: False,
)
mock_hook.return_value.get_job.return_value = job
+ mock_hook.return_value.generate_job_id.return_value = real_job_id
op = BigQueryInsertJobOperator(
task_id="insert_query_job",
@@ -974,14 +966,11 @@ class TestBigQueryInsertJobOperator:
assert result == real_job_id
- @mock.patch('airflow.providers.google.cloud.operators.bigquery.hashlib.md5')
- @mock.patch('airflow.providers.google.cloud.operators.bigquery.uuid')
@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
- def test_execute_force_rerun(self, mock_hook, mock_uuid, mock_md5):
+ def test_execute_force_rerun(self, mock_hook):
job_id = "123456"
- hash_ = mock_uuid.uuid4.return_value.encode.return_value
+ hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"
- mock_md5.return_value.hexdigest.return_value = hash_
configuration = {
"query": {
@@ -995,6 +984,7 @@ class TestBigQueryInsertJobOperator:
error_result=False,
)
mock_hook.return_value.insert_job.return_value = job
+ mock_hook.return_value.generate_job_id.return_value = real_job_id
op = BigQueryInsertJobOperator(
task_id="insert_query_job",
@@ -1017,13 +1007,11 @@ class TestBigQueryInsertJobOperator:
assert result == real_job_id
- @mock.patch('airflow.providers.google.cloud.operators.bigquery.hashlib.md5')
@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
- def test_execute_no_force_rerun(self, mock_hook, mock_md5):
+ def test_execute_no_force_rerun(self, mock_hook):
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"
- mock_md5.return_value.hexdigest.return_value = hash_
configuration = {
"query": {
@@ -1033,6 +1021,7 @@ class TestBigQueryInsertJobOperator:
}
mock_hook.return_value.insert_job.return_value.result.side_effect = Conflict("any")
+ mock_hook.return_value.generate_job_id.return_value = real_job_id
job = MagicMock(
job_id=real_job_id,
error_result=False,
@@ -1051,26 +1040,4 @@ class TestBigQueryInsertJobOperator:
)
# No force rerun
with pytest.raises(AirflowException):
- op.execute({})
-
- @mock.patch('airflow.providers.google.cloud.operators.bigquery.hashlib.md5')
- @pytest.mark.parametrize(
- "test_dag_id, expected_job_id",
- [("test-dag-id-1.1",
"airflow_test_dag_id_1_1_test_job_id_2020_01_23T00_00_00_00_00_hash")],
- ids=["test-dag-id-1.1"],
- )
- def test_job_id_validity(self, mock_md5, test_dag_id, expected_job_id):
- hash_ = "hash"
- mock_md5.return_value.hexdigest.return_value = hash_
- context = {"logical_date": datetime(2020, 1, 23)}
- configuration = {
- "query": {
- "query": "SELECT * FROM any",
- "useLegacySql": False,
- }
- }
- with DAG(dag_id=test_dag_id, start_date=datetime(2020, 1, 23)):
- op = BigQueryInsertJobOperator(
- task_id="test_job_id", configuration=configuration,
project_id=TEST_GCP_PROJECT_ID
- )
- assert op._job_id(context) == expected_job_id
+ op.execute(context=MagicMock())
diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py
index b627d3672c..ada0af296d 100644
--- a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py
@@ -18,6 +18,9 @@
import unittest
from unittest import mock
+from unittest.mock import MagicMock
+
+from google.cloud.bigquery.retry import DEFAULT_RETRY
from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator
@@ -37,6 +40,29 @@ class TestBigQueryToGCSOperator(unittest.TestCase):
field_delimiter = ','
print_header = True
labels = {'k1': 'v1'}
+ job_id = "123456"
+ hash_ = "hash"
+ real_job_id = f"{job_id}_{hash_}"
+
+ expected_configuration = {
+ 'extract': {
+ 'sourceTable': {
+ 'projectId': 'test-project-id',
+ 'datasetId': 'test-dataset',
+ 'tableId': 'test-table-id',
+ },
+ 'compression': 'NONE',
+ 'destinationUris': ['gs://some-bucket/some-file.txt'],
+ 'destinationFormat': 'CSV',
+ 'fieldDelimiter': ',',
+ 'printHeader': True,
+ },
+ 'labels': {'k1': 'v1'},
+ }
+
+ mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID)
+ mock_hook.return_value.generate_job_id.return_value = real_job_id
+ mock_hook.return_value.insert_job.return_value = MagicMock(job_id="real_job_id", error_result=False)
operator = BigQueryToGCSOperator(
task_id=TASK_ID,
@@ -48,16 +74,13 @@ class TestBigQueryToGCSOperator(unittest.TestCase):
print_header=print_header,
labels=labels,
)
-
operator.execute(context=mock.MagicMock())
- mock_hook.return_value.run_extract.assert_called_once_with(
- source_project_dataset_table=source_project_dataset_table,
- destination_cloud_storage_uris=destination_cloud_storage_uris,
- compression=compression,
- export_format=export_format,
- field_delimiter=field_delimiter,
- print_header=print_header,
- labels=labels,
- return_full_job=True,
+ mock_hook.return_value.insert_job.assert_called_once_with(
+ job_id='123456_hash',
+ configuration=expected_configuration,
+ project_id=None,
+ location=None,
+ timeout=None,
+ retry=DEFAULT_RETRY,
)