This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch v2-0-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit a3b6e47229da2f7da982d1f6afad6aab394bbd27 Author: Kamil Breguła <[email protected]> AuthorDate: Mon Jan 11 09:39:44 2021 +0100 Support google-cloud-automl >=2.1.0 (#13505) (cherry picked from commit a6f999b62e3c9aeb10ab24342674d3670a8ad259) --- airflow/providers/google/ADDITIONAL_INFO.md | 1 + .../cloud/example_dags/example_automl_tables.py | 6 +- airflow/providers/google/cloud/hooks/automl.py | 103 +++++++++++---------- airflow/providers/google/cloud/operators/automl.py | 36 +++---- setup.py | 2 +- tests/providers/google/cloud/hooks/test_automl.py | 70 +++++++------- .../google/cloud/operators/test_automl.py | 29 ++++-- 7 files changed, 134 insertions(+), 113 deletions(-) diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md index d80f9e1..800703b 100644 --- a/airflow/providers/google/ADDITIONAL_INFO.md +++ b/airflow/providers/google/ADDITIONAL_INFO.md @@ -29,6 +29,7 @@ Details are covered in the UPDATING.md files for each library, but there are som | Library name | Previous constraints | Current constraints | | | --- | --- | --- | --- | +| [``google-cloud-automl``](https://pypi.org/project/google-cloud-automl/) | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-automl/blob/master/UPGRADING.md) | | [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) | | [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) | | [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | 
[`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) | diff --git a/airflow/providers/google/cloud/example_dags/example_automl_tables.py b/airflow/providers/google/cloud/example_dags/example_automl_tables.py index 4ff92b3..117bd34 100644 --- a/airflow/providers/google/cloud/example_dags/example_automl_tables.py +++ b/airflow/providers/google/cloud/example_dags/example_automl_tables.py @@ -47,7 +47,7 @@ GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") GCP_AUTOML_DATASET_BUCKET = os.environ.get( "GCP_AUTOML_DATASET_BUCKET", "gs://cloud-ml-tables-data/bank-marketing.csv" ) -TARGET = os.environ.get("GCP_AUTOML_TARGET", "Class") +TARGET = os.environ.get("GCP_AUTOML_TARGET", "Deposit") # Example values MODEL_ID = "TBL123456" @@ -76,9 +76,9 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str: Using column name returns spec of the column. """ for column in columns_specs: - if column["displayName"] == column_name: + if column["display_name"] == column_name: return extract_object_id(column) - return "" + raise Exception(f"Unknown target column: {column_name}") # Example DAG to create dataset, train model_id and deploy it. 
diff --git a/airflow/providers/google/cloud/hooks/automl.py b/airflow/providers/google/cloud/hooks/automl.py index 78ec4fb..75d7037 100644 --- a/airflow/providers/google/cloud/hooks/automl.py +++ b/airflow/providers/google/cloud/hooks/automl.py @@ -20,22 +20,23 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union from cached_property import cached_property +from google.api_core.operation import Operation from google.api_core.retry import Retry -from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient -from google.cloud.automl_v1beta1.types import ( +from google.cloud.automl_v1beta1 import ( + AutoMlClient, BatchPredictInputConfig, BatchPredictOutputConfig, ColumnSpec, Dataset, ExamplePayload, - FieldMask, ImageObjectDetectionModelDeploymentMetadata, InputConfig, Model, - Operation, + PredictionServiceClient, PredictResponse, TableSpec, ) +from google.protobuf.field_mask_pb2 import FieldMask from airflow.providers.google.common.hooks.base_google import GoogleBaseHook @@ -123,9 +124,9 @@ class CloudAutoMLHook(GoogleBaseHook): :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance """ client = self.get_conn() - parent = client.location_path(project_id, location) + parent = f"projects/{project_id}/locations/{location}" return client.create_model( - parent=parent, model=model, retry=retry, timeout=timeout, metadata=metadata + request={'parent': parent, 'model': model}, retry=retry, timeout=timeout, metadata=metadata or () ) @GoogleBaseHook.fallback_to_default_project_id @@ -176,15 +177,17 @@ class CloudAutoMLHook(GoogleBaseHook): :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance """ client = self.prediction_client - name = client.model_path(project=project_id, location=location, model=model_id) + name = f"projects/{project_id}/locations/{location}/models/{model_id}" result = client.batch_predict( - name=name, - input_config=input_config, - output_config=output_config, - params=params, + request={ + 
'name': name, + 'input_config': input_config, + 'output_config': output_config, + 'params': params, + }, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) return result @@ -229,14 +232,12 @@ class CloudAutoMLHook(GoogleBaseHook): :return: `google.cloud.automl_v1beta1.types.PredictResponse` instance """ client = self.prediction_client - name = client.model_path(project=project_id, location=location, model=model_id) + name = f"projects/{project_id}/locations/{location}/models/{model_id}" result = client.predict( - name=name, - payload=payload, - params=params, + request={'name': name, 'payload': payload, 'params': params}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) return result @@ -273,13 +274,12 @@ class CloudAutoMLHook(GoogleBaseHook): :return: `google.cloud.automl_v1beta1.types.Dataset` instance. """ client = self.get_conn() - parent = client.location_path(project=project_id, location=location) + parent = f"projects/{project_id}/locations/{location}" result = client.create_dataset( - parent=parent, - dataset=dataset, + request={'parent': parent, 'dataset': dataset}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) return result @@ -319,13 +319,12 @@ class CloudAutoMLHook(GoogleBaseHook): :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance """ client = self.get_conn() - name = client.dataset_path(project=project_id, location=location, dataset=dataset_id) + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" result = client.import_data( - name=name, - input_config=input_config, + request={'name': name, 'input_config': input_config}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) return result @@ -385,13 +384,10 @@ class CloudAutoMLHook(GoogleBaseHook): table_spec=table_spec_id, ) result = client.list_column_specs( - parent=parent, - field_mask=field_mask, - filter_=filter_, - 
page_size=page_size, + request={'parent': parent, 'field_mask': field_mask, 'filter': filter_, 'page_size': page_size}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) return result @@ -427,8 +423,10 @@ class CloudAutoMLHook(GoogleBaseHook): :return: `google.cloud.automl_v1beta1.types.Model` instance. """ client = self.get_conn() - name = client.model_path(project=project_id, location=location, model=model_id) - result = client.get_model(name=name, retry=retry, timeout=timeout, metadata=metadata) + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + result = client.get_model( + request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or () + ) return result @GoogleBaseHook.fallback_to_default_project_id @@ -463,8 +461,10 @@ class CloudAutoMLHook(GoogleBaseHook): :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance. """ client = self.get_conn() - name = client.model_path(project=project_id, location=location, model=model_id) - result = client.delete_model(name=name, retry=retry, timeout=timeout, metadata=metadata) + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + result = client.delete_model( + request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or () + ) return result def update_dataset( @@ -497,11 +497,10 @@ class CloudAutoMLHook(GoogleBaseHook): """ client = self.get_conn() result = client.update_dataset( - dataset=dataset, - update_mask=update_mask, + request={'dataset': dataset, 'update_mask': update_mask}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) return result @@ -547,13 +546,15 @@ class CloudAutoMLHook(GoogleBaseHook): :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance. 
""" client = self.get_conn() - name = client.model_path(project=project_id, location=location, model=model_id) + name = f"projects/{project_id}/locations/{location}/models/{model_id}" result = client.deploy_model( - name=name, + request={ + 'name': name, + 'image_object_detection_model_deployment_metadata': image_detection_metadata, + }, retry=retry, timeout=timeout, - metadata=metadata, - image_object_detection_model_deployment_metadata=image_detection_metadata, + metadata=metadata or (), ) return result @@ -601,14 +602,12 @@ class CloudAutoMLHook(GoogleBaseHook): of the response through the `options` parameter. """ client = self.get_conn() - parent = client.dataset_path(project=project_id, location=location, dataset=dataset_id) + parent = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" result = client.list_table_specs( - parent=parent, - filter_=filter_, - page_size=page_size, + request={'parent': parent, 'filter': filter_, 'page_size': page_size}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) return result @@ -644,8 +643,10 @@ class CloudAutoMLHook(GoogleBaseHook): of the response through the `options` parameter. 
""" client = self.get_conn() - parent = client.location_path(project=project_id, location=location) - result = client.list_datasets(parent=parent, retry=retry, timeout=timeout, metadata=metadata) + parent = f"projects/{project_id}/locations/{location}" + result = client.list_datasets( + request={'parent': parent}, retry=retry, timeout=timeout, metadata=metadata or () + ) return result @GoogleBaseHook.fallback_to_default_project_id @@ -680,6 +681,8 @@ class CloudAutoMLHook(GoogleBaseHook): :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance """ client = self.get_conn() - name = client.dataset_path(project=project_id, location=location, dataset=dataset_id) - result = client.delete_dataset(name=name, retry=retry, timeout=timeout, metadata=metadata) + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.delete_dataset( + request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or () + ) return result diff --git a/airflow/providers/google/cloud/operators/automl.py b/airflow/providers/google/cloud/operators/automl.py index a1823cd..cdf79b0 100644 --- a/airflow/providers/google/cloud/operators/automl.py +++ b/airflow/providers/google/cloud/operators/automl.py @@ -22,7 +22,14 @@ import ast from typing import Dict, List, Optional, Sequence, Tuple, Union from google.api_core.retry import Retry -from google.protobuf.json_format import MessageToDict +from google.cloud.automl_v1beta1 import ( + BatchPredictResult, + ColumnSpec, + Dataset, + Model, + PredictResponse, + TableSpec, +) from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook @@ -113,7 +120,7 @@ class AutoMLTrainModelOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - result = MessageToDict(operation.result()) + result = Model.to_dict(operation.result()) model_id = hook.extract_object_id(result) self.log.info("Model created: %s", model_id) @@ -212,7 +219,7 @@ class 
AutoMLPredictOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(result) + return PredictResponse.to_dict(result) class AutoMLBatchPredictOperator(BaseOperator): @@ -324,7 +331,7 @@ class AutoMLBatchPredictOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - result = MessageToDict(operation.result()) + result = BatchPredictResult.to_dict(operation.result()) self.log.info("Batch prediction ready.") return result @@ -414,7 +421,7 @@ class AutoMLCreateDatasetOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - result = MessageToDict(result) + result = Dataset.to_dict(result) dataset_id = hook.extract_object_id(result) self.log.info("Creating completed. Dataset id: %s", dataset_id) @@ -513,9 +520,8 @@ class AutoMLImportDataOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - result = MessageToDict(operation.result()) + operation.result() self.log.info("Import completed") - return result class AutoMLTablesListColumnSpecsOperator(BaseOperator): @@ -627,7 +633,7 @@ class AutoMLTablesListColumnSpecsOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - result = [MessageToDict(spec) for spec in page_iterator] + result = [ColumnSpec.to_dict(spec) for spec in page_iterator] self.log.info("Columns specs obtained.") return result @@ -718,7 +724,7 @@ class AutoMLTablesUpdateDatasetOperator(BaseOperator): metadata=self.metadata, ) self.log.info("Dataset updated.") - return MessageToDict(result) + return Dataset.to_dict(result) class AutoMLGetModelOperator(BaseOperator): @@ -804,7 +810,7 @@ class AutoMLGetModelOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(result) + return Model.to_dict(result) class AutoMLDeleteModelOperator(BaseOperator): @@ -890,8 +896,7 @@ class AutoMLDeleteModelOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - result = MessageToDict(operation.result()) - return 
result + operation.result() class AutoMLDeployModelOperator(BaseOperator): @@ -991,9 +996,8 @@ class AutoMLDeployModelOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - result = MessageToDict(operation.result()) + operation.result() self.log.info("Model deployed.") - return result class AutoMLTablesListTableSpecsOperator(BaseOperator): @@ -1092,7 +1096,7 @@ class AutoMLTablesListTableSpecsOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - result = [MessageToDict(spec) for spec in page_iterator] + result = [TableSpec.to_dict(spec) for spec in page_iterator] self.log.info(result) self.log.info("Table specs obtained.") return result @@ -1173,7 +1177,7 @@ class AutoMLListDatasetOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - result = [MessageToDict(dataset) for dataset in page_iterator] + result = [Dataset.to_dict(dataset) for dataset in page_iterator] self.log.info("Datasets obtained.") self.xcom_push( diff --git a/setup.py b/setup.py index 5314814..ff9e65d 100644 --- a/setup.py +++ b/setup.py @@ -283,7 +283,7 @@ google = [ 'google-api-python-client>=1.6.0,<2.0.0', 'google-auth>=1.0.0,<2.0.0', 'google-auth-httplib2>=0.0.1', - 'google-cloud-automl>=0.4.0,<2.0.0', + 'google-cloud-automl>=2.1.0,<3.0.0', 'google-cloud-bigquery-datatransfer>=3.0.0,<4.0.0', 'google-cloud-bigtable>=1.0.0,<2.0.0', 'google-cloud-container>=0.1.1,<2.0.0', diff --git a/tests/providers/google/cloud/hooks/test_automl.py b/tests/providers/google/cloud/hooks/test_automl.py index 898001c..c9de712 100644 --- a/tests/providers/google/cloud/hooks/test_automl.py +++ b/tests/providers/google/cloud/hooks/test_automl.py @@ -19,7 +19,7 @@ import unittest from unittest import mock -from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient +from google.cloud.automl_v1beta1 import AutoMlClient from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from tests.providers.google.cloud.utils.base_gcp_mock 
import mock_base_gcp_hook_no_default_project_id @@ -38,9 +38,9 @@ MODEL = { "tables_model_metadata": {"train_budget_milli_node_hours": 1000}, } -LOCATION_PATH = AutoMlClient.location_path(GCP_PROJECT_ID, GCP_LOCATION) -MODEL_PATH = PredictionServiceClient.model_path(GCP_PROJECT_ID, GCP_LOCATION, MODEL_ID) -DATASET_PATH = AutoMlClient.dataset_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID) +LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}" +MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}" +DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}" INPUT_CONFIG = {"input": "value"} OUTPUT_CONFIG = {"output": "value"} @@ -81,7 +81,7 @@ class TestAuoMLHook(unittest.TestCase): self.hook.create_model(model=MODEL, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) mock_create_model.assert_called_once_with( - parent=LOCATION_PATH, model=MODEL, retry=None, timeout=None, metadata=None + request=dict(parent=LOCATION_PATH, model=MODEL), retry=None, timeout=None, metadata=() ) @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.batch_predict") @@ -95,13 +95,12 @@ class TestAuoMLHook(unittest.TestCase): ) mock_batch_predict.assert_called_once_with( - name=MODEL_PATH, - input_config=INPUT_CONFIG, - output_config=OUTPUT_CONFIG, - params=None, + request=dict( + name=MODEL_PATH, input_config=INPUT_CONFIG, output_config=OUTPUT_CONFIG, params=None + ), retry=None, timeout=None, - metadata=None, + metadata=(), ) @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.predict") @@ -114,12 +113,10 @@ class TestAuoMLHook(unittest.TestCase): ) mock_predict.assert_called_once_with( - name=MODEL_PATH, - payload=PAYLOAD, - params=None, + request=dict(name=MODEL_PATH, payload=PAYLOAD, params=None), retry=None, timeout=None, - metadata=None, + metadata=(), ) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_dataset") @@ -127,11 
+124,10 @@ class TestAuoMLHook(unittest.TestCase): self.hook.create_dataset(dataset=DATASET, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) mock_create_dataset.assert_called_once_with( - parent=LOCATION_PATH, - dataset=DATASET, + request=dict(parent=LOCATION_PATH, dataset=DATASET), retry=None, timeout=None, - metadata=None, + metadata=(), ) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.import_data") @@ -144,11 +140,10 @@ class TestAuoMLHook(unittest.TestCase): ) mock_import_data.assert_called_once_with( - name=DATASET_PATH, - input_config=INPUT_CONFIG, + request=dict(name=DATASET_PATH, input_config=INPUT_CONFIG), retry=None, timeout=None, - metadata=None, + metadata=(), ) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_column_specs") @@ -169,26 +164,27 @@ class TestAuoMLHook(unittest.TestCase): parent = AutoMlClient.table_spec_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID, table_spec) mock_list_column_specs.assert_called_once_with( - parent=parent, - field_mask=MASK, - filter_=filter_, - page_size=page_size, + request=dict(parent=parent, field_mask=MASK, filter=filter_, page_size=page_size), retry=None, timeout=None, - metadata=None, + metadata=(), ) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.get_model") def test_get_model(self, mock_get_model): self.hook.get_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) - mock_get_model.assert_called_once_with(name=MODEL_PATH, retry=None, timeout=None, metadata=None) + mock_get_model.assert_called_once_with( + request=dict(name=MODEL_PATH), retry=None, timeout=None, metadata=() + ) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_model") def test_delete_model(self, mock_delete_model): self.hook.delete_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) - mock_delete_model.assert_called_once_with(name=MODEL_PATH, retry=None, timeout=None, metadata=None) + 
mock_delete_model.assert_called_once_with( + request=dict(name=MODEL_PATH), retry=None, timeout=None, metadata=() + ) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.update_dataset") def test_update_dataset(self, mock_update_dataset): @@ -198,7 +194,7 @@ class TestAuoMLHook(unittest.TestCase): ) mock_update_dataset.assert_called_once_with( - dataset=DATASET, update_mask=MASK, retry=None, timeout=None, metadata=None + request=dict(dataset=DATASET, update_mask=MASK), retry=None, timeout=None, metadata=() ) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.deploy_model") @@ -213,11 +209,13 @@ class TestAuoMLHook(unittest.TestCase): ) mock_deploy_model.assert_called_once_with( - name=MODEL_PATH, + request=dict( + name=MODEL_PATH, + image_object_detection_model_deployment_metadata=image_detection_metadata, + ), retry=None, timeout=None, - metadata=None, - image_object_detection_model_deployment_metadata=image_detection_metadata, + metadata=(), ) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_table_specs") @@ -234,12 +232,10 @@ class TestAuoMLHook(unittest.TestCase): ) mock_list_table_specs.assert_called_once_with( - parent=DATASET_PATH, - filter_=filter_, - page_size=page_size, + request=dict(parent=DATASET_PATH, filter=filter_, page_size=page_size), retry=None, timeout=None, - metadata=None, + metadata=(), ) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_datasets") @@ -247,7 +243,7 @@ class TestAuoMLHook(unittest.TestCase): self.hook.list_datasets(location=GCP_LOCATION, project_id=GCP_PROJECT_ID) mock_list_datasets.assert_called_once_with( - parent=LOCATION_PATH, retry=None, timeout=None, metadata=None + request=dict(parent=LOCATION_PATH), retry=None, timeout=None, metadata=() ) @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_dataset") @@ -255,5 +251,5 @@ class TestAuoMLHook(unittest.TestCase): self.hook.delete_dataset(dataset_id=DATASET_ID, 
location=GCP_LOCATION, project_id=GCP_PROJECT_ID) mock_delete_dataset.assert_called_once_with( - name=DATASET_PATH, retry=None, timeout=None, metadata=None + request=dict(name=DATASET_PATH), retry=None, timeout=None, metadata=() ) diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py index 903600b..4c80703 100644 --- a/tests/providers/google/cloud/operators/test_automl.py +++ b/tests/providers/google/cloud/operators/test_automl.py @@ -20,8 +20,9 @@ import copy import unittest from unittest import mock -from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient +from google.cloud.automl_v1beta1 import BatchPredictResult, Dataset, Model, PredictResponse +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook from airflow.providers.google.cloud.operators.automl import ( AutoMLBatchPredictOperator, AutoMLCreateDatasetOperator, @@ -43,7 +44,7 @@ TASK_ID = "test-automl-hook" GCP_PROJECT_ID = "test-project" GCP_LOCATION = "test-location" MODEL_NAME = "test_model" -MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152" +MODEL_ID = "TBL9195602771183665152" DATASET_ID = "TBL123456789" MODEL = { "display_name": MODEL_NAME, @@ -51,8 +52,9 @@ MODEL = { "tables_model_metadata": {"train_budget_milli_node_hours": 1000}, } -LOCATION_PATH = AutoMlClient.location_path(GCP_PROJECT_ID, GCP_LOCATION) -MODEL_PATH = PredictionServiceClient.model_path(GCP_PROJECT_ID, GCP_LOCATION, MODEL_ID) +LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}" +MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}" +DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}" INPUT_CONFIG = {"input": "value"} OUTPUT_CONFIG = {"output": "value"} @@ -60,12 +62,15 @@ PAYLOAD = {"test": "payload"} DATASET = {"dataset_id": "data"} MASK = {"field": "mask"} +extract_object_id = 
CloudAutoMLHook.extract_object_id + class TestAutoMLTrainModelOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLTrainModelOperator.xcom_push") @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") def test_execute(self, mock_hook, mock_xcom): - mock_hook.return_value.extract_object_id.return_value = MODEL_ID + mock_hook.return_value.create_model.return_value.result.return_value = Model(name=MODEL_PATH) + mock_hook.return_value.extract_object_id = extract_object_id op = AutoMLTrainModelOperator( model=MODEL, location=GCP_LOCATION, @@ -87,6 +92,9 @@ class TestAutoMLTrainModelOperator(unittest.TestCase): class TestAutoMLBatchPredictOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") def test_execute(self, mock_hook): + mock_hook.return_value.batch_predict.return_value.result.return_value = BatchPredictResult() + mock_hook.return_value.extract_object_id = extract_object_id + op = AutoMLBatchPredictOperator( model_id=MODEL_ID, location=GCP_LOCATION, @@ -113,6 +121,8 @@ class TestAutoMLBatchPredictOperator(unittest.TestCase): class TestAutoMLPredictOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") def test_execute(self, mock_hook): + mock_hook.return_value.predict.return_value = PredictResponse() + op = AutoMLPredictOperator( model_id=MODEL_ID, location=GCP_LOCATION, @@ -137,7 +147,9 @@ class TestAutoMLCreateImportOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator.xcom_push") @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") def test_execute(self, mock_hook, mock_xcom): - mock_hook.return_value.extract_object_id.return_value = DATASET_ID + mock_hook.return_value.create_dataset.return_value = Dataset(name=DATASET_PATH) + mock_hook.return_value.extract_object_id = extract_object_id + op = 
AutoMLCreateDatasetOperator( dataset=DATASET, location=GCP_LOCATION, @@ -191,6 +203,8 @@ class TestAutoMLListColumnsSpecsOperator(unittest.TestCase): class TestAutoMLUpdateDatasetOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") def test_execute(self, mock_hook): + mock_hook.return_value.update_dataset.return_value = Dataset(name=DATASET_PATH) + dataset = copy.deepcopy(DATASET) dataset["name"] = DATASET_ID @@ -213,6 +227,9 @@ class TestAutoMLUpdateDatasetOperator(unittest.TestCase): class TestAutoMLGetModelOperator(unittest.TestCase): @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") def test_execute(self, mock_hook): + mock_hook.return_value.get_model.return_value = Model(name=MODEL_PATH) + mock_hook.return_value.extract_object_id = extract_object_id + op = AutoMLGetModelOperator( model_id=MODEL_ID, location=GCP_LOCATION,
