This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch v2-0-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 02cb5e1af6c4d1b2823729d3f2801fd9d05bdf43 Author: Kamil Breguła <[email protected]> AuthorDate: Mon Jan 11 09:39:19 2021 +0100 Support google-cloud-datacatalog>=3.0.0 (#13534) (cherry picked from commit 947dbb73bba736eb146f33117545a18fc2fd3c09) --- airflow/providers/google/ADDITIONAL_INFO.md | 2 +- .../cloud/example_dags/example_datacatalog.py | 10 +- .../providers/google/cloud/hooks/datacatalog.py | 220 ++++++++++++------- .../google/cloud/operators/datacatalog.py | 47 ++-- setup.py | 2 +- .../google/cloud/hooks/test_datacatalog.py | 237 +++++++++++++-------- .../google/cloud/operators/test_datacatalog.py | 49 +++-- 7 files changed, 357 insertions(+), 210 deletions(-) diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md index eca05df..d80f9e1 100644 --- a/airflow/providers/google/ADDITIONAL_INFO.md +++ b/airflow/providers/google/ADDITIONAL_INFO.md @@ -30,7 +30,7 @@ Details are covered in the UPDATING.md files for each library, but there are som | Library name | Previous constraints | Current constraints | | | --- | --- | --- | --- | | [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) | -| [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=1.0.0,<2.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) | +| [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) | | [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | 
[`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) | | [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) | | [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) | diff --git a/airflow/providers/google/cloud/example_dags/example_datacatalog.py b/airflow/providers/google/cloud/example_dags/example_datacatalog.py index c8597a6..cc4b73a 100644 --- a/airflow/providers/google/cloud/example_dags/example_datacatalog.py +++ b/airflow/providers/google/cloud/example_dags/example_datacatalog.py @@ -19,7 +19,7 @@ """ Example Airflow DAG that interacts with Google Data Catalog service """ -from google.cloud.datacatalog_v1beta1.proto.tags_pb2 import FieldType, TagField, TagTemplateField +from google.cloud.datacatalog_v1beta1 import FieldType, TagField, TagTemplateField from airflow import models from airflow.operators.bash_operator import BashOperator @@ -91,7 +91,7 @@ with models.DAG("example_gcp_datacatalog", start_date=days_ago(1), schedule_inte entry_id=ENTRY_ID, entry={ "display_name": "Wizard", - "type": "FILESET", + "type_": "FILESET", "gcs_fileset_spec": {"file_patterns": ["gs://test-datacatalog/**"]}, }, ) @@ -144,7 +144,7 @@ with models.DAG("example_gcp_datacatalog", start_date=days_ago(1), schedule_inte "display_name": "Awesome Tag Template", "fields": { FIELD_NAME_1: TagTemplateField( - display_name="first-field", type=FieldType(primitive_type="STRING") + display_name="first-field", type_=dict(primitive_type="STRING") ) }, }, @@ -172,7 +172,7 @@ with models.DAG("example_gcp_datacatalog", start_date=days_ago(1), schedule_inte tag_template=TEMPLATE_ID, tag_template_field_id=FIELD_NAME_2, tag_template_field=TagTemplateField( - 
display_name="second-field", type=FieldType(primitive_type="STRING") + display_name="second-field", type_=FieldType(primitive_type="STRING") ), ) # [END howto_operator_gcp_datacatalog_create_tag_template_field] @@ -305,7 +305,7 @@ with models.DAG("example_gcp_datacatalog", start_date=days_ago(1), schedule_inte # [START howto_operator_gcp_datacatalog_lookup_entry_result] lookup_entry_result = BashOperator( task_id="lookup_entry_result", - bash_command="echo \"{{ task_instance.xcom_pull('lookup_entry')['displayName'] }}\"", + bash_command="echo \"{{ task_instance.xcom_pull('lookup_entry')['display_name'] }}\"", ) # [END howto_operator_gcp_datacatalog_lookup_entry_result] diff --git a/airflow/providers/google/cloud/hooks/datacatalog.py b/airflow/providers/google/cloud/hooks/datacatalog.py index 70b488d..0d6cc75 100644 --- a/airflow/providers/google/cloud/hooks/datacatalog.py +++ b/airflow/providers/google/cloud/hooks/datacatalog.py @@ -18,16 +18,18 @@ from typing import Dict, Optional, Sequence, Tuple, Union from google.api_core.retry import Retry -from google.cloud.datacatalog_v1beta1 import DataCatalogClient -from google.cloud.datacatalog_v1beta1.types import ( +from google.cloud import datacatalog +from google.cloud.datacatalog_v1beta1 import ( + CreateTagRequest, + DataCatalogClient, Entry, EntryGroup, - FieldMask, SearchCatalogRequest, Tag, TagTemplate, TagTemplateField, ) +from google.protobuf.field_mask_pb2 import FieldMask from airflow import AirflowException from airflow.providers.google.common.hooks.base_google import GoogleBaseHook @@ -115,10 +117,13 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - parent = DataCatalogClient.entry_group_path(project_id, location, entry_group) + parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}" self.log.info('Creating a new entry: parent=%s', parent) result = client.create_entry( - parent=parent, entry_id=entry_id, 
entry=entry, retry=retry, timeout=timeout, metadata=metadata + request={'parent': parent, 'entry_id': entry_id, 'entry': entry}, + retry=retry, + timeout=timeout, + metadata=metadata or (), ) self.log.info('Created a entry: name=%s', result.name) return result @@ -161,16 +166,14 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - parent = DataCatalogClient.location_path(project_id, location) + parent = f"projects/{project_id}/locations/{location}" self.log.info('Creating a new entry group: parent=%s', parent) result = client.create_entry_group( - parent=parent, - entry_group_id=entry_group_id, - entry_group=entry_group, + request={'parent': parent, 'entry_group_id': entry_group_id, 'entry_group': entry_group}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) self.log.info('Created a entry group: name=%s', result.name) @@ -218,15 +221,34 @@ class CloudDataCatalogHook(GoogleBaseHook): """ client = self.get_conn() if template_id: - template_path = DataCatalogClient.tag_template_path(project_id, location, template_id) + template_path = f"projects/{project_id}/locations/{location}/tagTemplates/{template_id}" if isinstance(tag, Tag): tag.template = template_path else: tag["template"] = template_path - parent = DataCatalogClient.entry_path(project_id, location, entry_group, entry) + parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" self.log.info('Creating a new tag: parent=%s', parent) - result = client.create_tag(parent=parent, tag=tag, retry=retry, timeout=timeout, metadata=metadata) + # HACK: google-cloud-datacatalog has problems with mapping messages where the value is not a + # primitive type, so we need to convert it manually. 
+ # See: https://github.com/googleapis/python-datacatalog/issues/84 + if isinstance(tag, dict): + tag = Tag( + name=tag.get('name'), + template=tag.get('template'), + template_display_name=tag.get('template_display_name'), + column=tag.get('column'), + fields={ + k: datacatalog.TagField(**v) if isinstance(v, dict) else v + for k, v in tag.get("fields", {}).items() + }, + ) + request = CreateTagRequest( + parent=parent, + tag=tag, + ) + + result = client.create_tag(request=request, retry=retry, timeout=timeout, metadata=metadata or ()) self.log.info('Created a tag: name=%s', result.name) return result @@ -267,17 +289,30 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - parent = DataCatalogClient.location_path(project_id, location) + parent = f"projects/{project_id}/locations/{location}" self.log.info('Creating a new tag template: parent=%s', parent) + # HACK: google-cloud-datacatalog has problems with mapping messages where the value is not a + # primitive type, so we need to convert it manually. 
+ # See: https://github.com/googleapis/python-datacatalog/issues/84 + if isinstance(tag_template, dict): + tag_template = datacatalog.TagTemplate( + name=tag_template.get("name"), + display_name=tag_template.get("display_name"), + fields={ + k: datacatalog.TagTemplateField(**v) if isinstance(v, dict) else v + for k, v in tag_template.get("fields", {}).items() + }, + ) + request = datacatalog.CreateTagTemplateRequest( + parent=parent, tag_template_id=tag_template_id, tag_template=tag_template + ) result = client.create_tag_template( - parent=parent, - tag_template_id=tag_template_id, - tag_template=tag_template, + request=request, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) self.log.info('Created a tag template: name=%s', result.name) @@ -325,17 +360,19 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - parent = DataCatalogClient.tag_template_path(project_id, location, tag_template) + parent = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}" self.log.info('Creating a new tag template field: parent=%s', parent) result = client.create_tag_template_field( - parent=parent, - tag_template_field_id=tag_template_field_id, - tag_template_field=tag_template_field, + request={ + 'parent': parent, + 'tag_template_field_id': tag_template_field_id, + 'tag_template_field': tag_template_field, + }, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) self.log.info('Created a tag template field: name=%s', result.name) @@ -375,9 +412,9 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - name = DataCatalogClient.entry_path(project_id, location, entry_group, entry) + name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" self.log.info('Deleting a entry: name=%s', name) - client.delete_entry(name=name, retry=retry, 
timeout=timeout, metadata=metadata) + client.delete_entry(request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()) self.log.info('Deleted a entry: name=%s', name) @GoogleBaseHook.fallback_to_default_project_id @@ -412,10 +449,12 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - name = DataCatalogClient.entry_group_path(project_id, location, entry_group) + name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}" self.log.info('Deleting a entry group: name=%s', name) - client.delete_entry_group(name=name, retry=retry, timeout=timeout, metadata=metadata) + client.delete_entry_group( + request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or () + ) self.log.info('Deleted a entry group: name=%s', name) @GoogleBaseHook.fallback_to_default_project_id @@ -454,10 +493,12 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - name = DataCatalogClient.tag_path(project_id, location, entry_group, entry, tag) + name = ( + f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}/tags/{tag}" + ) self.log.info('Deleting a tag: name=%s', name) - client.delete_tag(name=name, retry=retry, timeout=timeout, metadata=metadata) + client.delete_tag(request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()) self.log.info('Deleted a tag: name=%s', name) @GoogleBaseHook.fallback_to_default_project_id @@ -495,10 +536,12 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - name = DataCatalogClient.tag_template_path(project_id, location, tag_template) + name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}" self.log.info('Deleting a tag template: name=%s', name) - client.delete_tag_template(name=name, force=force, retry=retry, timeout=timeout, metadata=metadata) 
+ client.delete_tag_template( + request={'name': name, 'force': force}, retry=retry, timeout=timeout, metadata=metadata or () + ) self.log.info('Deleted a tag template: name=%s', name) @GoogleBaseHook.fallback_to_default_project_id @@ -537,11 +580,11 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - name = DataCatalogClient.tag_template_field_path(project_id, location, tag_template, field) + name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}/fields/{field}" self.log.info('Deleting a tag template field: name=%s', name) client.delete_tag_template_field( - name=name, force=force, retry=retry, timeout=timeout, metadata=metadata + request={'name': name, 'force': force}, retry=retry, timeout=timeout, metadata=metadata or () ) self.log.info('Deleted a tag template field: name=%s', name) @@ -578,10 +621,12 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - name = DataCatalogClient.entry_path(project_id, location, entry_group, entry) + name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" self.log.info('Getting a entry: name=%s', name) - result = client.get_entry(name=name, retry=retry, timeout=timeout, metadata=metadata) + result = client.get_entry( + request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or () + ) self.log.info('Received a entry: name=%s', result.name) return result @@ -607,8 +652,8 @@ class CloudDataCatalogHook(GoogleBaseHook): :param read_mask: The fields to return. If not set or empty, all fields are returned. 
If a dict is provided, it must be of the same form as the protobuf message - :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask` - :type read_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask] + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type read_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] :param project_id: The ID of the Google Cloud project that owns the entry group. If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. :type project_id: str @@ -622,12 +667,15 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - name = DataCatalogClient.entry_group_path(project_id, location, entry_group) + name = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}" self.log.info('Getting a entry group: name=%s', name) result = client.get_entry_group( - name=name, read_mask=read_mask, retry=retry, timeout=timeout, metadata=metadata + request={'name': name, 'read_mask': read_mask}, + retry=retry, + timeout=timeout, + metadata=metadata or (), ) self.log.info('Received a entry group: name=%s', result.name) @@ -664,11 +712,13 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - name = DataCatalogClient.tag_template_path(project_id, location, tag_template) + name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}" self.log.info('Getting a tag template: name=%s', name) - result = client.get_tag_template(name=name, retry=retry, timeout=timeout, metadata=metadata) + result = client.get_tag_template( + request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or () + ) self.log.info('Received a tag template: name=%s', result.name) @@ -712,12 +762,15 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - parent = 
DataCatalogClient.entry_path(project_id, location, entry_group, entry) + parent = f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" self.log.info('Listing tag on entry: entry_name=%s', parent) result = client.list_tags( - parent=parent, page_size=page_size, retry=retry, timeout=timeout, metadata=metadata + request={'parent': parent, 'page_size': page_size}, + retry=retry, + timeout=timeout, + metadata=metadata or (), ) self.log.info('Received tags.') @@ -811,12 +864,18 @@ class CloudDataCatalogHook(GoogleBaseHook): if linked_resource: self.log.info('Getting entry: linked_resource=%s', linked_resource) result = client.lookup_entry( - linked_resource=linked_resource, retry=retry, timeout=timeout, metadata=metadata + request={'linked_resource': linked_resource}, + retry=retry, + timeout=timeout, + metadata=metadata or (), ) else: self.log.info('Getting entry: sql_resource=%s', sql_resource) result = client.lookup_entry( - sql_resource=sql_resource, retry=retry, timeout=timeout, metadata=metadata + request={'sql_resource': sql_resource}, + retry=retry, + timeout=timeout, + metadata=metadata or (), ) self.log.info('Received entry. 
name=%s', result.name) @@ -860,18 +919,17 @@ class CloudDataCatalogHook(GoogleBaseHook): :type metadata: Sequence[Tuple[str, str]] """ client = self.get_conn() - name = DataCatalogClient.tag_template_field_path(project_id, location, tag_template, field) + name = f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}/fields/{field}" self.log.info( 'Renaming field: old_name=%s, new_tag_template_field_id=%s', name, new_tag_template_field_id ) result = client.rename_tag_template_field( - name=name, - new_tag_template_field_id=new_tag_template_field_id, + request={'name': name, 'new_tag_template_field_id': new_tag_template_field_id}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) self.log.info('Renamed tag template field.') @@ -946,13 +1004,10 @@ class CloudDataCatalogHook(GoogleBaseHook): order_by, ) result = client.search_catalog( - scope=scope, - query=query, - page_size=page_size, - order_by=order_by, + request={'scope': scope, 'query': query, 'page_size': page_size, 'order_by': order_by}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) self.log.info('Received items.') @@ -984,8 +1039,8 @@ class CloudDataCatalogHook(GoogleBaseHook): updated. If a dict is provided, it must be of the same form as the protobuf message - :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask` - :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask] + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] :param location: Required. The location of the entry to update. :type location: str :param entry_group: The entry group ID for the entry that is being updated. 
@@ -1006,7 +1061,9 @@ class CloudDataCatalogHook(GoogleBaseHook): """ client = self.get_conn() if project_id and location and entry_group and entry_id: - full_entry_name = DataCatalogClient.entry_path(project_id, location, entry_group, entry_id) + full_entry_name = ( + f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry_id}" + ) if isinstance(entry, Entry): entry.name = full_entry_name elif isinstance(entry, dict): @@ -1025,7 +1082,10 @@ class CloudDataCatalogHook(GoogleBaseHook): if isinstance(entry, dict): entry = Entry(**entry) result = client.update_entry( - entry=entry, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata + request={'entry': entry, 'update_mask': update_mask}, + retry=retry, + timeout=timeout, + metadata=metadata or (), ) self.log.info('Updated entry.') @@ -1059,7 +1119,7 @@ class CloudDataCatalogHook(GoogleBaseHook): If a dict is provided, it must be of the same form as the protobuf message :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask` - :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask] + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] :param location: Required. The location of the tag to rename. :type location: str :param entry_group: The entry group ID for the tag that is being updated. 
@@ -1082,7 +1142,10 @@ class CloudDataCatalogHook(GoogleBaseHook): """ client = self.get_conn() if project_id and location and entry_group and entry and tag_id: - full_tag_name = DataCatalogClient.tag_path(project_id, location, entry_group, entry, tag_id) + full_tag_name = ( + f"projects/{project_id}/locations/{location}/entryGroups/{entry_group}/entries/{entry}" + f"/tags/{tag_id}" + ) if isinstance(tag, Tag): tag.name = full_tag_name elif isinstance(tag, dict): @@ -1102,7 +1165,10 @@ class CloudDataCatalogHook(GoogleBaseHook): if isinstance(tag, dict): tag = Tag(**tag) result = client.update_tag( - tag=tag, update_mask=update_mask, retry=retry, timeout=timeout, metadata=metadata + request={'tag': tag, 'update_mask': update_mask}, + retry=retry, + timeout=timeout, + metadata=metadata or (), ) self.log.info('Updated tag.') @@ -1137,8 +1203,8 @@ class CloudDataCatalogHook(GoogleBaseHook): If absent or empty, all of the allowed fields above will be updated. If a dict is provided, it must be of the same form as the protobuf message - :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask` - :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask] + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] :param location: Required. The location of the tag template to rename. :type location: str :param tag_template_id: Optional. The tag template ID for the entry that is being updated. 
@@ -1157,8 +1223,8 @@ class CloudDataCatalogHook(GoogleBaseHook): """ client = self.get_conn() if project_id and location and tag_template: - full_tag_template_name = DataCatalogClient.tag_template_path( - project_id, location, tag_template_id + full_tag_template_name = ( + f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template_id}" ) if isinstance(tag_template, TagTemplate): tag_template.name = full_tag_template_name @@ -1179,11 +1245,10 @@ class CloudDataCatalogHook(GoogleBaseHook): if isinstance(tag_template, dict): tag_template = TagTemplate(**tag_template) result = client.update_tag_template( - tag_template=tag_template, - update_mask=update_mask, + request={'tag_template': tag_template, 'update_mask': update_mask}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) self.log.info('Updated tag template.') @@ -1222,8 +1287,8 @@ class CloudDataCatalogHook(GoogleBaseHook): Therefore, enum values can only be added, existing enum values cannot be deleted nor renamed. If a dict is provided, it must be of the same form as the protobuf message - :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask` - :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask] + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] :param tag_template_field_name: Optional. The name of the tag template field to rename. :type tag_template_field_name: str :param location: Optional. The location of the tag to rename. 
@@ -1246,19 +1311,22 @@ class CloudDataCatalogHook(GoogleBaseHook): """ client = self.get_conn() if project_id and location and tag_template and tag_template_field_id: - tag_template_field_name = DataCatalogClient.tag_template_field_path( - project_id, location, tag_template, tag_template_field_id + tag_template_field_name = ( + f"projects/{project_id}/locations/{location}/tagTemplates/{tag_template}" + f"/fields/{tag_template_field_id}" ) self.log.info("Updating tag template field: name=%s", tag_template_field_name) result = client.update_tag_template_field( - name=tag_template_field_name, - tag_template_field=tag_template_field, - update_mask=update_mask, + request={ + 'name': tag_template_field_name, + 'tag_template_field': tag_template_field, + 'update_mask': update_mask, + }, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) self.log.info('Updated tag template field.') diff --git a/airflow/providers/google/cloud/operators/datacatalog.py b/airflow/providers/google/cloud/operators/datacatalog.py index 00b2765..4b0da05 100644 --- a/airflow/providers/google/cloud/operators/datacatalog.py +++ b/airflow/providers/google/cloud/operators/datacatalog.py @@ -19,17 +19,16 @@ from typing import Dict, Optional, Sequence, Tuple, Union from google.api_core.exceptions import AlreadyExists, NotFound from google.api_core.retry import Retry -from google.cloud.datacatalog_v1beta1 import DataCatalogClient +from google.cloud.datacatalog_v1beta1 import DataCatalogClient, SearchCatalogResult from google.cloud.datacatalog_v1beta1.types import ( Entry, EntryGroup, - FieldMask, SearchCatalogRequest, Tag, TagTemplate, TagTemplateField, ) -from google.protobuf.json_format import MessageToDict +from google.protobuf.field_mask_pb2 import FieldMask from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.datacatalog import CloudDataCatalogHook @@ -153,7 +152,7 @@ class CloudDataCatalogCreateEntryOperator(BaseOperator): _, _, entry_id = 
result.name.rpartition("/") self.log.info("Current entry_id ID: %s", entry_id) context["task_instance"].xcom_push(key="entry_id", value=entry_id) - return MessageToDict(result) + return Entry.to_dict(result) class CloudDataCatalogCreateEntryGroupOperator(BaseOperator): @@ -268,7 +267,7 @@ class CloudDataCatalogCreateEntryGroupOperator(BaseOperator): _, _, entry_group_id = result.name.rpartition("/") self.log.info("Current entry group ID: %s", entry_group_id) context["task_instance"].xcom_push(key="entry_group_id", value=entry_group_id) - return MessageToDict(result) + return EntryGroup.to_dict(result) class CloudDataCatalogCreateTagOperator(BaseOperator): @@ -404,7 +403,7 @@ class CloudDataCatalogCreateTagOperator(BaseOperator): _, _, tag_id = tag.name.rpartition("/") self.log.info("Current Tag ID: %s", tag_id) context["task_instance"].xcom_push(key="tag_id", value=tag_id) - return MessageToDict(tag) + return Tag.to_dict(tag) class CloudDataCatalogCreateTagTemplateOperator(BaseOperator): @@ -516,7 +515,7 @@ class CloudDataCatalogCreateTagTemplateOperator(BaseOperator): _, _, tag_template = result.name.rpartition("/") self.log.info("Current Tag ID: %s", tag_template) context["task_instance"].xcom_push(key="tag_template_id", value=tag_template) - return MessageToDict(result) + return TagTemplate.to_dict(result) class CloudDataCatalogCreateTagTemplateFieldOperator(BaseOperator): @@ -638,7 +637,7 @@ class CloudDataCatalogCreateTagTemplateFieldOperator(BaseOperator): self.log.info("Current Tag ID: %s", self.tag_template_field_id) context["task_instance"].xcom_push(key="tag_template_field_id", value=self.tag_template_field_id) - return MessageToDict(result) + return TagTemplateField.to_dict(result) class CloudDataCatalogDeleteEntryOperator(BaseOperator): @@ -1216,7 +1215,7 @@ class CloudDataCatalogGetEntryOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(result) + return Entry.to_dict(result) class 
CloudDataCatalogGetEntryGroupOperator(BaseOperator): @@ -1234,8 +1233,8 @@ class CloudDataCatalogGetEntryGroupOperator(BaseOperator): :param read_mask: The fields to return. If not set or empty, all fields are returned. If a dict is provided, it must be of the same form as the protobuf message - :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask` - :type read_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask] + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type read_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] :param project_id: The ID of the Google Cloud project that owns the entry group. If set to ``None`` or missing, the default project_id from the Google Cloud connection is used. :type project_id: Optional[str] @@ -1312,7 +1311,7 @@ class CloudDataCatalogGetEntryGroupOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(result) + return EntryGroup.to_dict(result) class CloudDataCatalogGetTagTemplateOperator(BaseOperator): @@ -1399,7 +1398,7 @@ class CloudDataCatalogGetTagTemplateOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(result) + return TagTemplate.to_dict(result) class CloudDataCatalogListTagsOperator(BaseOperator): @@ -1501,7 +1500,7 @@ class CloudDataCatalogListTagsOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return [MessageToDict(item) for item in result] + return [Tag.to_dict(item) for item in result] class CloudDataCatalogLookupEntryOperator(BaseOperator): @@ -1589,7 +1588,7 @@ class CloudDataCatalogLookupEntryOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(result) + return Entry.to_dict(result) class CloudDataCatalogRenameTagTemplateFieldOperator(BaseOperator): @@ -1809,7 +1808,7 @@ class CloudDataCatalogSearchCatalogOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return [MessageToDict(item) for item in result] + 
return [SearchCatalogResult.to_dict(item) for item in result] class CloudDataCatalogUpdateEntryOperator(BaseOperator): @@ -1829,8 +1828,8 @@ class CloudDataCatalogUpdateEntryOperator(BaseOperator): updated. If a dict is provided, it must be of the same form as the protobuf message - :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask` - :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask] + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] :param location: Required. The location of the entry to update. :type location: str :param entry_group: The entry group ID for the entry that is being updated. @@ -1940,8 +1939,8 @@ class CloudDataCatalogUpdateTagOperator(BaseOperator): updated. Currently the only modifiable field is the field ``fields``. If a dict is provided, it must be of the same form as the protobuf message - :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask` - :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask] + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] :param location: Required. The location of the tag to rename. :type location: str :param entry_group: The entry group ID for the tag that is being updated. @@ -2060,8 +2059,8 @@ class CloudDataCatalogUpdateTagTemplateOperator(BaseOperator): If absent or empty, all of the allowed fields above will be updated. If a dict is provided, it must be of the same form as the protobuf message - :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask` - :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask] + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] :param location: Required. The location of the tag template to rename. :type location: str :param tag_template_id: Optional. 
The tag template ID for the entry that is being updated. @@ -2172,8 +2171,8 @@ class CloudDataCatalogUpdateTagTemplateFieldOperator(BaseOperator): Therefore, enum values can only be added, existing enum values cannot be deleted nor renamed. If a dict is provided, it must be of the same form as the protobuf message - :class:`~google.cloud.datacatalog_v1beta1.types.FieldMask` - :type update_mask: Union[Dict, google.cloud.datacatalog_v1beta1.types.FieldMask] + :class:`~google.protobuf.field_mask_pb2.FieldMask` + :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask] :param tag_template_field_name: Optional. The name of the tag template field to rename. :type tag_template_field_name: str :param location: Optional. The location of the tag to rename. diff --git a/setup.py b/setup.py index 75f5db5..5314814 100644 --- a/setup.py +++ b/setup.py @@ -287,7 +287,7 @@ google = [ 'google-cloud-bigquery-datatransfer>=3.0.0,<4.0.0', 'google-cloud-bigtable>=1.0.0,<2.0.0', 'google-cloud-container>=0.1.1,<2.0.0', - 'google-cloud-datacatalog>=1.0.0,<2.0.0', + 'google-cloud-datacatalog>=3.0.0,<4.0.0', 'google-cloud-dataproc>=1.0.1,<2.0.0', 'google-cloud-dlp>=0.11.0,<2.0.0', 'google-cloud-kms>=2.0.0,<3.0.0', diff --git a/tests/providers/google/cloud/hooks/test_datacatalog.py b/tests/providers/google/cloud/hooks/test_datacatalog.py index f5192c5..99d785f 100644 --- a/tests/providers/google/cloud/hooks/test_datacatalog.py +++ b/tests/providers/google/cloud/hooks/test_datacatalog.py @@ -22,6 +22,7 @@ from unittest import TestCase, mock import pytest from google.api_core.retry import Retry +from google.cloud.datacatalog_v1beta1 import CreateTagRequest, CreateTagTemplateRequest from google.cloud.datacatalog_v1beta1.types import Entry, Tag, TagTemplate from airflow import AirflowException @@ -38,7 +39,7 @@ TEST_ENTRY_ID: str = "test-entry-id" TEST_ENTRY: Dict = {} TEST_RETRY: Retry = Retry() TEST_TIMEOUT: float = 4 -TEST_METADATA: Sequence[Tuple[str, str]] = [] 
+TEST_METADATA: Sequence[Tuple[str, str]] = () TEST_ENTRY_GROUP_ID: str = "test-entry-group-id" TEST_ENTRY_GROUP: Dict = {} TEST_TAG: Dict = {} @@ -102,7 +103,7 @@ class TestCloudDataCatalog(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.lookup_entry.assert_called_once_with( - linked_resource=TEST_LINKED_RESOURCE, + request=dict(linked_resource=TEST_LINKED_RESOURCE), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -118,7 +119,10 @@ class TestCloudDataCatalog(TestCase): sql_resource=TEST_SQL_RESOURCE, retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA ) mock_get_conn.return_value.lookup_entry.assert_called_once_with( - sql_resource=TEST_SQL_RESOURCE, retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA + request=dict(sql_resource=TEST_SQL_RESOURCE), + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, ) @mock.patch( @@ -148,10 +152,9 @@ class TestCloudDataCatalog(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.search_catalog.assert_called_once_with( - scope=TEST_SCOPE, - query=TEST_QUERY, - page_size=TEST_PAGE_SIZE, - order_by=TEST_ORDER_BY, + request=dict( + scope=TEST_SCOPE, query=TEST_QUERY, page_size=TEST_PAGE_SIZE, order_by=TEST_ORDER_BY + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -184,9 +187,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.create_entry.assert_called_once_with( - parent=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1), - entry_id=TEST_ENTRY_ID, - entry=TEST_ENTRY, + request=dict( + parent=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1), + entry_id=TEST_ENTRY_ID, + entry=TEST_ENTRY, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -207,9 +212,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.create_entry_group.assert_called_once_with( - 
parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_1), - entry_group_id=TEST_ENTRY_GROUP_ID, - entry_group=TEST_ENTRY_GROUP, + request=dict( + parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_1), + entry_group_id=TEST_ENTRY_GROUP_ID, + entry_group=TEST_ENTRY_GROUP, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -232,8 +239,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.create_tag.assert_called_once_with( - parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1), - tag={"template": TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)}, + request=CreateTagRequest( + parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1), + tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)), + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -256,8 +265,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.create_tag.assert_called_once_with( - parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1), - tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)), + request=CreateTagRequest( + parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1), + tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)), + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -278,9 +289,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.create_tag_template.assert_called_once_with( - parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_1), - tag_template_id=TEST_TAG_TEMPLATE_ID, - tag_template=TEST_TAG_TEMPLATE, + request=CreateTagTemplateRequest( + parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_1), + tag_template_id=TEST_TAG_TEMPLATE_ID, + tag_template=TEST_TAG_TEMPLATE, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -302,9 +315,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): 
metadata=TEST_METADATA, ) mock_get_conn.return_value.create_tag_template_field.assert_called_once_with( - parent=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1), - tag_template_field_id=TEST_TAG_TEMPLATE_FIELD_ID, - tag_template_field=TEST_TAG_TEMPLATE_FIELD, + request=dict( + parent=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1), + tag_template_field_id=TEST_TAG_TEMPLATE_FIELD_ID, + tag_template_field=TEST_TAG_TEMPLATE_FIELD, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -325,7 +340,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.delete_entry.assert_called_once_with( - name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1), + request=dict( + name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1), + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -345,7 +362,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.delete_entry_group.assert_called_once_with( - name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1), + request=dict( + name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1), + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -367,7 +386,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.delete_tag.assert_called_once_with( - name=TEST_TAG_PATH.format(TEST_PROJECT_ID_1), + request=dict( + name=TEST_TAG_PATH.format(TEST_PROJECT_ID_1), + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -388,8 +409,7 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.delete_tag_template.assert_called_once_with( - name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1), - force=TEST_FORCE, + request=dict(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1), force=TEST_FORCE), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, 
@@ -411,8 +431,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.delete_tag_template_field.assert_called_once_with( - name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1), - force=TEST_FORCE, + request=dict( + name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1), + force=TEST_FORCE, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -433,7 +455,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.get_entry.assert_called_once_with( - name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1), + request=dict( + name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1), + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -454,8 +478,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.get_entry_group.assert_called_once_with( - name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1), - read_mask=TEST_READ_MASK, + request=dict( + name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_1), + read_mask=TEST_READ_MASK, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -475,7 +501,9 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.get_tag_template.assert_called_once_with( - name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1), + request=dict( + name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1), + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -497,8 +525,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.list_tags.assert_called_once_with( - parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1), - page_size=TEST_PAGE_SIZE, + request=dict( + parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1), + page_size=TEST_PAGE_SIZE, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, 
metadata=TEST_METADATA, @@ -524,8 +554,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.list_tags.assert_called_once_with( - parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1), - page_size=100, + request=dict( + parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1), + page_size=100, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -548,8 +580,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.rename_tag_template_field.assert_called_once_with( - name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1), - new_tag_template_field_id=TEST_NEW_TAG_TEMPLATE_FIELD_ID, + request=dict( + name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1), + new_tag_template_field_id=TEST_NEW_TAG_TEMPLATE_FIELD_ID, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -572,8 +606,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.update_entry.assert_called_once_with( - entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1)), - update_mask=TEST_UPDATE_MASK, + request=dict( + entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_1)), + update_mask=TEST_UPDATE_MASK, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -597,8 +633,7 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.update_tag.assert_called_once_with( - tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_1)), - update_mask=TEST_UPDATE_MASK, + request=dict(tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_1)), update_mask=TEST_UPDATE_MASK), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -620,8 +655,10 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.update_tag_template.assert_called_once_with( - 
tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)), - update_mask=TEST_UPDATE_MASK, + request=dict( + tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_1)), + update_mask=TEST_UPDATE_MASK, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -644,9 +681,11 @@ class TestCloudDataCatalogWithDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.update_tag_template_field.assert_called_once_with( - name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1), - tag_template_field=TEST_TAG_TEMPLATE_FIELD, - update_mask=TEST_UPDATE_MASK, + request=dict( + name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_1), + tag_template_field=TEST_TAG_TEMPLATE_FIELD, + update_mask=TEST_UPDATE_MASK, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -680,9 +719,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.create_entry.assert_called_once_with( - parent=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2), - entry_id=TEST_ENTRY_ID, - entry=TEST_ENTRY, + request=dict( + parent=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2), + entry_id=TEST_ENTRY_ID, + entry=TEST_ENTRY, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -704,9 +745,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.create_entry_group.assert_called_once_with( - parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_2), - entry_group_id=TEST_ENTRY_GROUP_ID, - entry_group=TEST_ENTRY_GROUP, + request=dict( + parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_2), + entry_group_id=TEST_ENTRY_GROUP_ID, + entry_group=TEST_ENTRY_GROUP, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -730,8 +773,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) 
mock_get_conn.return_value.create_tag.assert_called_once_with( - parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), - tag={"template": TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)}, + request=CreateTagRequest( + parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), + tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)), + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -755,8 +800,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.create_tag.assert_called_once_with( - parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), - tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)), + request=CreateTagRequest( + parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), + tag=Tag(template=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)), + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -778,9 +825,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.create_tag_template.assert_called_once_with( - parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_2), - tag_template_id=TEST_TAG_TEMPLATE_ID, - tag_template=TEST_TAG_TEMPLATE, + request=CreateTagTemplateRequest( + parent=TEST_LOCATION_PATH.format(TEST_PROJECT_ID_2), + tag_template_id=TEST_TAG_TEMPLATE_ID, + tag_template=TEST_TAG_TEMPLATE, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -803,9 +852,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.create_tag_template_field.assert_called_once_with( - parent=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2), - tag_template_field_id=TEST_TAG_TEMPLATE_FIELD_ID, - tag_template_field=TEST_TAG_TEMPLATE_FIELD, + request=dict( + parent=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2), + tag_template_field_id=TEST_TAG_TEMPLATE_FIELD_ID, + tag_template_field=TEST_TAG_TEMPLATE_FIELD, + ), 
retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -827,7 +878,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.delete_entry.assert_called_once_with( - name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), + request=dict(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2)), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -848,7 +899,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.delete_entry_group.assert_called_once_with( - name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2), + request=dict(name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2)), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -871,7 +922,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.delete_tag.assert_called_once_with( - name=TEST_TAG_PATH.format(TEST_PROJECT_ID_2), + request=dict(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_2)), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -893,8 +944,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.delete_tag_template.assert_called_once_with( - name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2), - force=TEST_FORCE, + request=dict(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2), force=TEST_FORCE), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -917,8 +967,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.delete_tag_template_field.assert_called_once_with( - name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2), - force=TEST_FORCE, + request=dict(name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2), force=TEST_FORCE), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -940,7 +989,7 
@@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.get_entry.assert_called_once_with( - name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), + request=dict(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2)), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -962,8 +1011,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.get_entry_group.assert_called_once_with( - name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2), - read_mask=TEST_READ_MASK, + request=dict( + name=TEST_ENTRY_GROUP_PATH.format(TEST_PROJECT_ID_2), + read_mask=TEST_READ_MASK, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -984,7 +1035,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.get_tag_template.assert_called_once_with( - name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2), + request=dict(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -1007,8 +1058,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.list_tags.assert_called_once_with( - parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), - page_size=TEST_PAGE_SIZE, + request=dict(parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), page_size=TEST_PAGE_SIZE), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -1035,8 +1085,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.list_tags.assert_called_once_with( - parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), - page_size=100, + request=dict(parent=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2), page_size=100), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -1060,8 +1109,10 @@ class 
TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.rename_tag_template_field.assert_called_once_with( - name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2), - new_tag_template_field_id=TEST_NEW_TAG_TEMPLATE_FIELD_ID, + request=dict( + name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2), + new_tag_template_field_id=TEST_NEW_TAG_TEMPLATE_FIELD_ID, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -1085,8 +1136,9 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.update_entry.assert_called_once_with( - entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2)), - update_mask=TEST_UPDATE_MASK, + request=dict( + entry=Entry(name=TEST_ENTRY_PATH.format(TEST_PROJECT_ID_2)), update_mask=TEST_UPDATE_MASK + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -1111,8 +1163,7 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.update_tag.assert_called_once_with( - tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_2)), - update_mask=TEST_UPDATE_MASK, + request=dict(tag=Tag(name=TEST_TAG_PATH.format(TEST_PROJECT_ID_2)), update_mask=TEST_UPDATE_MASK), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -1135,8 +1186,10 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) mock_get_conn.return_value.update_tag_template.assert_called_once_with( - tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)), - update_mask=TEST_UPDATE_MASK, + request=dict( + tag_template=TagTemplate(name=TEST_TAG_TEMPLATE_PATH.format(TEST_PROJECT_ID_2)), + update_mask=TEST_UPDATE_MASK, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, @@ -1160,9 +1213,11 @@ class TestCloudDataCatalogWithoutDefaultProjectIdHook(TestCase): metadata=TEST_METADATA, ) 
mock_get_conn.return_value.update_tag_template_field.assert_called_once_with( - name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2), - tag_template_field=TEST_TAG_TEMPLATE_FIELD, - update_mask=TEST_UPDATE_MASK, + request=dict( + name=TEST_TAG_TEMPLATE_FIELD_PATH.format(TEST_PROJECT_ID_2), + tag_template_field=TEST_TAG_TEMPLATE_FIELD, + update_mask=TEST_UPDATE_MASK, + ), retry=TEST_RETRY, timeout=TEST_TIMEOUT, metadata=TEST_METADATA, diff --git a/tests/providers/google/cloud/operators/test_datacatalog.py b/tests/providers/google/cloud/operators/test_datacatalog.py index b575dd4..517b35c 100644 --- a/tests/providers/google/cloud/operators/test_datacatalog.py +++ b/tests/providers/google/cloud/operators/test_datacatalog.py @@ -87,15 +87,25 @@ TEST_TAG_PATH: str = ( ) TEST_ENTRY: Entry = Entry(name=TEST_ENTRY_PATH) -TEST_ENTRY_DICT: Dict = dict(name=TEST_ENTRY_PATH) +TEST_ENTRY_DICT: Dict = { + 'description': '', + 'display_name': '', + 'linked_resource': '', + 'name': TEST_ENTRY_PATH, +} TEST_ENTRY_GROUP: EntryGroup = EntryGroup(name=TEST_ENTRY_GROUP_PATH) -TEST_ENTRY_GROUP_DICT: Dict = dict(name=TEST_ENTRY_GROUP_PATH) -TEST_TAG: EntryGroup = Tag(name=TEST_TAG_PATH) -TEST_TAG_DICT: Dict = dict(name=TEST_TAG_PATH) +TEST_ENTRY_GROUP_DICT: Dict = {'description': '', 'display_name': '', 'name': TEST_ENTRY_GROUP_PATH} +TEST_TAG: Tag = Tag(name=TEST_TAG_PATH) +TEST_TAG_DICT: Dict = {'fields': {}, 'name': TEST_TAG_PATH, 'template': '', 'template_display_name': ''} TEST_TAG_TEMPLATE: TagTemplate = TagTemplate(name=TEST_TAG_TEMPLATE_PATH) -TEST_TAG_TEMPLATE_DICT: Dict = dict(name=TEST_TAG_TEMPLATE_PATH) -TEST_TAG_TEMPLATE_FIELD: Dict = TagTemplateField(name=TEST_TAG_TEMPLATE_FIELD_ID) -TEST_TAG_TEMPLATE_FIELD_DICT: Dict = dict(name=TEST_TAG_TEMPLATE_FIELD_ID) +TEST_TAG_TEMPLATE_DICT: Dict = {'display_name': '', 'fields': {}, 'name': TEST_TAG_TEMPLATE_PATH} +TEST_TAG_TEMPLATE_FIELD: TagTemplateField = TagTemplateField(name=TEST_TAG_TEMPLATE_FIELD_ID) 
+TEST_TAG_TEMPLATE_FIELD_DICT: Dict = { + 'display_name': '', + 'is_required': False, + 'name': TEST_TAG_TEMPLATE_FIELD_ID, + 'order': 0, +} class TestCloudDataCatalogCreateEntryOperator(TestCase): @@ -498,7 +508,10 @@ class TestCloudDataCatalogDeleteTagTemplateFieldOperator(TestCase): class TestCloudDataCatalogGetEntryOperator(TestCase): - @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook") + @mock.patch( + "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook", + **{"return_value.get_entry.return_value": TEST_ENTRY}, # type: ignore + ) def test_assert_valid_hook_call(self, mock_hook) -> None: task = CloudDataCatalogGetEntryOperator( task_id="task_id", @@ -529,7 +542,10 @@ class TestCloudDataCatalogGetEntryOperator(TestCase): class TestCloudDataCatalogGetEntryGroupOperator(TestCase): - @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook") + @mock.patch( + "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook", + **{"return_value.get_entry_group.return_value": TEST_ENTRY_GROUP}, # type: ignore + ) def test_assert_valid_hook_call(self, mock_hook) -> None: task = CloudDataCatalogGetEntryGroupOperator( task_id="task_id", @@ -560,7 +576,10 @@ class TestCloudDataCatalogGetEntryGroupOperator(TestCase): class TestCloudDataCatalogGetTagTemplateOperator(TestCase): - @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook") + @mock.patch( + "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook", + **{"return_value.get_tag_template.return_value": TEST_TAG_TEMPLATE}, # type: ignore + ) def test_assert_valid_hook_call(self, mock_hook) -> None: task = CloudDataCatalogGetTagTemplateOperator( task_id="task_id", @@ -589,7 +608,10 @@ class TestCloudDataCatalogGetTagTemplateOperator(TestCase): class TestCloudDataCatalogListTagsOperator(TestCase): - 
@mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook") + @mock.patch( + "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook", + **{"return_value.list_tags.return_value": [TEST_TAG]}, # type: ignore + ) def test_assert_valid_hook_call(self, mock_hook) -> None: task = CloudDataCatalogListTagsOperator( task_id="task_id", @@ -622,7 +644,10 @@ class TestCloudDataCatalogListTagsOperator(TestCase): class TestCloudDataCatalogLookupEntryOperator(TestCase): - @mock.patch("airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook") + @mock.patch( + "airflow.providers.google.cloud.operators.datacatalog.CloudDataCatalogHook", + **{"return_value.lookup_entry.return_value": TEST_ENTRY}, # type: ignore + ) def test_assert_valid_hook_call(self, mock_hook) -> None: task = CloudDataCatalogLookupEntryOperator( task_id="task_id",
