This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 90e6277da6 Auto ML assets (#25466) 90e6277da6 is described below commit 90e6277da6b4102cf565134739af10bafa9d3894 Author: Maksim <maks.yerma...@gmail.com> AuthorDate: Mon Jan 23 18:19:28 2023 +0300 Auto ML assets (#25466) --- .../cloud/example_dags/example_automl_tables.py | 319 --------------------- airflow/providers/google/cloud/links/automl.py | 163 +++++++++++ airflow/providers/google/cloud/operators/automl.py | 103 ++++++- airflow/providers/google/provider.yaml | 5 + .../operators/cloud/automl.rst | 26 +- .../google/cloud/operators/test_automl.py | 32 +-- .../google/cloud/operators/test_automl_system.py | 41 --- .../google/cloud/utils/gcp_authenticator.py | 1 - .../google/cloud/automl/example_automl_dataset.py | 201 +++++++++++++ .../google/cloud/automl/example_automl_model.py | 285 ++++++++++++++++++ .../google/cloud/automl/resources/__init__.py | 16 ++ 11 files changed, 797 insertions(+), 395 deletions(-) diff --git a/airflow/providers/google/cloud/example_dags/example_automl_tables.py b/airflow/providers/google/cloud/example_dags/example_automl_tables.py deleted file mode 100644 index 89006402f7..0000000000 --- a/airflow/providers/google/cloud/example_dags/example_automl_tables.py +++ /dev/null @@ -1,319 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Example Airflow DAG that uses Google AutoML services. -""" -from __future__ import annotations - -import os -from copy import deepcopy -from datetime import datetime -from typing import cast - -from airflow import models -from airflow.models.xcom_arg import XComArg -from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook -from airflow.providers.google.cloud.operators.automl import ( - AutoMLBatchPredictOperator, - AutoMLCreateDatasetOperator, - AutoMLDeleteDatasetOperator, - AutoMLDeleteModelOperator, - AutoMLDeployModelOperator, - AutoMLGetModelOperator, - AutoMLImportDataOperator, - AutoMLListDatasetOperator, - AutoMLPredictOperator, - AutoMLTablesListColumnSpecsOperator, - AutoMLTablesListTableSpecsOperator, - AutoMLTablesUpdateDatasetOperator, - AutoMLTrainModelOperator, -) - -START_DATE = datetime(2021, 1, 1) - -GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id") -GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1") -GCP_AUTOML_DATASET_BUCKET = os.environ.get( - "GCP_AUTOML_DATASET_BUCKET", "gs://INVALID BUCKET NAME/bank-marketing.csv" -) -TARGET = os.environ.get("GCP_AUTOML_TARGET", "Deposit") - -# Example values -MODEL_ID = "TBL123456" -DATASET_ID = "TBL123456" - -# Example model -MODEL = { - "display_name": "auto_model_1", - "dataset_id": DATASET_ID, - "tables_model_metadata": {"train_budget_milli_node_hours": 1000}, -} - -# Example dataset -DATASET = { - "display_name": "test_set", - "tables_dataset_metadata": {"target_column_spec_id": ""}, -} - -IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_DATASET_BUCKET]}} - -extract_object_id = CloudAutoMLHook.extract_object_id - - -def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str: - """ - Using column name returns spec of the column. - """ - for column in columns_specs: - if column["display_name"] == column_name: - return extract_object_id(column) - raise Exception(f"Unknown target column: {column_name}") - - -# Example DAG to create dataset, train model_id and deploy it. -with models.DAG( - "example_create_and_deploy", - start_date=START_DATE, - catchup=False, - user_defined_macros={ - "get_target_column_spec": get_target_column_spec, - "target": TARGET, - "extract_object_id": extract_object_id, - }, - tags=["example"], -) as create_deploy_dag: - # [START howto_operator_automl_create_dataset] - create_dataset_task = AutoMLCreateDatasetOperator( - task_id="create_dataset_task", - dataset=DATASET, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - - dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id")) - # [END howto_operator_automl_create_dataset] - - MODEL["dataset_id"] = dataset_id - - # [START howto_operator_automl_import_data] - import_dataset_task = AutoMLImportDataOperator( - task_id="import_dataset_task", - dataset_id=dataset_id, - location=GCP_AUTOML_LOCATION, - input_config=IMPORT_INPUT_CONFIG, - ) - # [END howto_operator_automl_import_data] - - # [START howto_operator_automl_specs] - list_tables_spec_task = AutoMLTablesListTableSpecsOperator( - task_id="list_tables_spec_task", - dataset_id=dataset_id, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_automl_specs] - - # [START howto_operator_automl_column_specs] - list_columns_spec_task = AutoMLTablesListColumnSpecsOperator( - task_id="list_columns_spec_task", - dataset_id=dataset_id, - table_spec_id="{{ extract_object_id(task_instance.xcom_pull('list_tables_spec_task')[0]) }}", - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_automl_column_specs] - - # [START howto_operator_automl_update_dataset] - update = deepcopy(DATASET) - update["name"] = '{{ task_instance.xcom_pull("create_dataset_task")["name"] }}' - update["tables_dataset_metadata"][ # type: ignore - "target_column_spec_id" - ] = "{{ get_target_column_spec(task_instance.xcom_pull('list_columns_spec_task'), target) }}" - - update_dataset_task = AutoMLTablesUpdateDatasetOperator( - task_id="update_dataset_task", - dataset=update, - location=GCP_AUTOML_LOCATION, - ) - # [END howto_operator_automl_update_dataset] - - # [START howto_operator_automl_create_model] - create_model_task = AutoMLTrainModelOperator( - task_id="create_model_task", - model=MODEL, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - - model_id = cast(str, XComArg(create_model_task, key="model_id")) - # [END howto_operator_automl_create_model] - - # [START howto_operator_automl_delete_model] - delete_model_task = AutoMLDeleteModelOperator( - task_id="delete_model_task", - model_id=model_id, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_automl_delete_model] - - delete_datasets_task = AutoMLDeleteDatasetOperator( - task_id="delete_datasets_task", - dataset_id=dataset_id, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - - ( - import_dataset_task - >> list_tables_spec_task - >> list_columns_spec_task - >> update_dataset_task - >> create_model_task - ) - delete_model_task >> delete_datasets_task - - # Task dependencies created via `XComArgs`: - # create_dataset_task >> import_dataset_task - # create_dataset_task >> list_tables_spec_task - # create_dataset_task >> list_columns_spec_task - # create_dataset_task >> create_model_task - # create_model_task >> delete_model_task - # create_dataset_task >> delete_datasets_task - - -# Example DAG for AutoML datasets operations -with models.DAG( - "example_automl_dataset", - start_date=START_DATE, - catchup=False, - user_defined_macros={"extract_object_id": extract_object_id}, -) as example_dag: - create_dataset_task2 = AutoMLCreateDatasetOperator( - task_id="create_dataset_task", - dataset=DATASET, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - - dataset_id = cast(str, XComArg(create_dataset_task2, key="dataset_id")) - - import_dataset_task = AutoMLImportDataOperator( - task_id="import_dataset_task", - dataset_id=dataset_id, - location=GCP_AUTOML_LOCATION, - input_config=IMPORT_INPUT_CONFIG, - ) - - list_tables_spec_task = AutoMLTablesListTableSpecsOperator( - task_id="list_tables_spec_task", - dataset_id=dataset_id, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - - list_columns_spec_task = AutoMLTablesListColumnSpecsOperator( - task_id="list_columns_spec_task", - dataset_id=dataset_id, - table_spec_id="{{ extract_object_id(task_instance.xcom_pull('list_tables_spec_task')[0]) }}", - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - - # [START howto_operator_list_dataset] - list_datasets_task = AutoMLListDatasetOperator( - task_id="list_datasets_task", - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_list_dataset] - - # [START howto_operator_delete_dataset] - delete_datasets_task = AutoMLDeleteDatasetOperator( - task_id="delete_datasets_task", - dataset_id="{{ task_instance.xcom_pull('list_datasets_task', key='dataset_id_list') | list }}", - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_delete_dataset] - - ( - import_dataset_task - >> list_tables_spec_task - >> list_columns_spec_task - >> list_datasets_task - >> delete_datasets_task - ) - - # Task dependencies created via `XComArgs`: - # create_dataset_task >> import_dataset_task - # create_dataset_task >> list_tables_spec_task - # create_dataset_task >> list_columns_spec_task - - -with models.DAG( - "example_gcp_get_deploy", - start_date=START_DATE, - catchup=False, - tags=["example"], -) as get_deploy_dag: - # [START howto_operator_get_model] - get_model_task = AutoMLGetModelOperator( - task_id="get_model_task", - model_id=MODEL_ID, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_get_model] - - # [START howto_operator_deploy_model] - deploy_model_task = AutoMLDeployModelOperator( - task_id="deploy_model_task", - model_id=MODEL_ID, - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_deploy_model] - - -with models.DAG( - "example_gcp_predict", - start_date=START_DATE, - catchup=False, - tags=["example"], -) as predict_dag: - # [START howto_operator_prediction] - predict_task = AutoMLPredictOperator( - task_id="predict_task", - model_id=MODEL_ID, - payload={}, # Add your own payload, the used model_id must be deployed - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_prediction] - - # [START howto_operator_batch_prediction] - batch_predict_task = AutoMLBatchPredictOperator( - task_id="batch_predict_task", - model_id=MODEL_ID, - input_config={}, # Add your config - output_config={}, # Add your config - location=GCP_AUTOML_LOCATION, - project_id=GCP_PROJECT_ID, - ) - # [END howto_operator_batch_prediction] diff --git a/airflow/providers/google/cloud/links/automl.py b/airflow/providers/google/cloud/links/automl.py new file mode 100644 index 0000000000..f2deee6642 --- /dev/null +++ b/airflow/providers/google/cloud/links/automl.py @@ -0,0 +1,163 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google AutoML links.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +AUTOML_BASE_LINK = "https://console.cloud.google.com/automl-tables" +AUTOML_DATASET_LINK = ( + AUTOML_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/schemav2?project={project_id}" +) +AUTOML_DATASET_LIST_LINK = AUTOML_BASE_LINK + "/datasets?project={project_id}" +AUTOML_MODEL_LINK = ( + AUTOML_BASE_LINK + + "/locations/{location}/datasets/{dataset_id};modelId={model_id}/evaluate?project={project_id}" +) +AUTOML_MODEL_TRAIN_LINK = ( + AUTOML_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/train?project={project_id}" +) +AUTOML_MODEL_PREDICT_LINK = ( + AUTOML_BASE_LINK + + "/locations/{location}/datasets/{dataset_id};modelId={model_id}/predict?project={project_id}" +) + + +class AutoMLDatasetLink(BaseGoogleLink): + """Helper class for constructing AutoML Dataset link""" + + name = "AutoML Dataset" + key = "automl_dataset" + format_str = AUTOML_DATASET_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + dataset_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=AutoMLDatasetLink.key, + value={"location": task_instance.location, "dataset_id": dataset_id, "project_id": project_id}, + ) + + +class AutoMLDatasetListLink(BaseGoogleLink): + """Helper class for constructing AutoML Dataset List link""" + + name = "AutoML Dataset List" + key = "automl_dataset_list" + format_str = AUTOML_DATASET_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context, + key=AutoMLDatasetListLink.key, + value={ + "project_id": project_id, + }, + ) + + +class AutoMLModelLink(BaseGoogleLink): + """Helper class for constructing AutoML Model link""" + + name = "AutoML Model" + key = "automl_model" + format_str = AUTOML_MODEL_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + dataset_id: str, + model_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=AutoMLModelLink.key, + value={ + "location": task_instance.location, + "dataset_id": dataset_id, + "model_id": model_id, + "project_id": project_id, + }, + ) + + +class AutoMLModelTrainLink(BaseGoogleLink): + """Helper class for constructing AutoML Model Train link""" + + name = "AutoML Model Train" + key = "automl_model_train" + format_str = AUTOML_MODEL_TRAIN_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context, + key=AutoMLModelTrainLink.key, + value={ + "location": task_instance.location, + "dataset_id": task_instance.model["dataset_id"], + "project_id": project_id, + }, + ) + + +class AutoMLModelPredictLink(BaseGoogleLink): + """Helper class for constructing AutoML Model Predict link""" + + name = "AutoML Model Predict" + key = "automl_model_predict" + format_str = AUTOML_MODEL_PREDICT_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + model_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=AutoMLModelPredictLink.key, + value={ + "location": task_instance.location, + "dataset_id": "-", + "model_id": model_id, + "project_id": project_id, + }, + ) diff --git a/airflow/providers/google/cloud/operators/automl.py b/airflow/providers/google/cloud/operators/automl.py index ff1ee00c3e..b1da9d3b77 100644 --- a/airflow/providers/google/cloud/operators/automl.py +++ b/airflow/providers/google/cloud/operators/automl.py @@ -34,6 +34,13 @@ from google.cloud.automl_v1beta1 import ( from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.links.automl import ( + AutoMLDatasetLink, + AutoMLDatasetListLink, + AutoMLModelLink, + AutoMLModelPredictLink, + AutoMLModelTrainLink, +) if TYPE_CHECKING: from airflow.utils.context import Context @@ -75,6 +82,10 @@ class AutoMLTrainModelOperator(BaseOperator): "project_id", "impersonation_chain", ) + operator_extra_links = ( + AutoMLModelTrainLink(), + AutoMLModelLink(), + ) def __init__( self, @@ -114,11 +125,22 @@ class AutoMLTrainModelOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) + project_id = self.project_id or hook.project_id + if project_id: + AutoMLModelTrainLink.persist(context=context, task_instance=self, project_id=project_id) result = Model.to_dict(operation.result()) model_id = hook.extract_object_id(result) self.log.info("Model created: %s", model_id) self.xcom_push(context, key="model_id", value=model_id) + if project_id: + AutoMLModelLink.persist( + context=context, + task_instance=self, + dataset_id=self.model["dataset_id"] or "-", + model_id=model_id, + project_id=project_id, + ) return result @@ -158,6 +180,7 @@ class AutoMLPredictOperator(BaseOperator): "project_id", "impersonation_chain", ) + operator_extra_links = (AutoMLModelPredictLink(),) def __init__( self, @@ -202,6 +225,14 @@ class AutoMLPredictOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) + project_id = self.project_id or hook.project_id + if project_id: + AutoMLModelPredictLink.persist( + context=context, + task_instance=self, + model_id=self.model_id, + project_id=project_id, + ) return PredictResponse.to_dict(result) @@ -252,6 +283,7 @@ class AutoMLBatchPredictOperator(BaseOperator): "project_id", "impersonation_chain", ) + operator_extra_links = (AutoMLModelPredictLink(),) def __init__( self, @@ -302,6 +334,14 @@ class AutoMLBatchPredictOperator(BaseOperator): ) result = BatchPredictResult.to_dict(operation.result()) self.log.info("Batch prediction ready.") + project_id = self.project_id or hook.project_id + if project_id: + AutoMLModelPredictLink.persist( + context=context, + task_instance=self, + model_id=self.model_id, + project_id=project_id, + ) return result @@ -341,6 +381,7 @@ class AutoMLCreateDatasetOperator(BaseOperator): "project_id", "impersonation_chain", ) + operator_extra_links = (AutoMLDatasetLink(),) def __init__( self, @@ -385,6 +426,14 @@ class AutoMLCreateDatasetOperator(BaseOperator): self.log.info("Creating completed. Dataset id: %s", dataset_id) self.xcom_push(context, key="dataset_id", value=dataset_id) + project_id = self.project_id or hook.project_id + if project_id: + AutoMLDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=dataset_id, + project_id=project_id, + ) return result @@ -426,6 +475,7 @@ class AutoMLImportDataOperator(BaseOperator): "project_id", "impersonation_chain", ) + operator_extra_links = (AutoMLDatasetLink(),) def __init__( self, @@ -470,6 +520,14 @@ class AutoMLImportDataOperator(BaseOperator): ) operation.result() self.log.info("Import completed") + project_id = self.project_id or hook.project_id + if project_id: + AutoMLDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=self.dataset_id, + project_id=project_id, + ) class AutoMLTablesListColumnSpecsOperator(BaseOperator): @@ -518,6 +576,7 @@ class AutoMLTablesListColumnSpecsOperator(BaseOperator): "project_id", "impersonation_chain", ) + operator_extra_links = (AutoMLDatasetLink(),) def __init__( self, @@ -570,7 +629,14 @@ class AutoMLTablesListColumnSpecsOperator(BaseOperator): ) result = [ColumnSpec.to_dict(spec) for spec in page_iterator] self.log.info("Columns specs obtained.") - + project_id = self.project_id or hook.project_id + if project_id: + AutoMLDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=self.dataset_id, + project_id=project_id, + ) return result @@ -610,6 +676,7 @@ class AutoMLTablesUpdateDatasetOperator(BaseOperator): "location", "impersonation_chain", ) + operator_extra_links = (AutoMLDatasetLink(),) def __init__( self, @@ -649,6 +716,14 @@ class AutoMLTablesUpdateDatasetOperator(BaseOperator): metadata=self.metadata, ) self.log.info("Dataset updated.") + project_id = hook.project_id + if project_id: + AutoMLDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=hook.extract_object_id(self.dataset), + project_id=project_id, + ) return Dataset.to_dict(result) @@ -687,6 +762,7 @@ class AutoMLGetModelOperator(BaseOperator): "project_id", "impersonation_chain", ) + operator_extra_links = (AutoMLModelLink(),) def __init__( self, @@ -725,7 +801,17 @@ class AutoMLGetModelOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return Model.to_dict(result) + model = Model.to_dict(result) + project_id = self.project_id or hook.project_id + if project_id: + AutoMLModelLink.persist( + context=context, + task_instance=self, + dataset_id=model["dataset_id"], + model_id=self.model_id, + project_id=project_id, + ) + return model class AutoMLDeleteModelOperator(BaseOperator): @@ -935,6 +1021,7 @@ class AutoMLTablesListTableSpecsOperator(BaseOperator): "project_id", "impersonation_chain", ) + operator_extra_links = (AutoMLDatasetLink(),) def __init__( self, @@ -982,6 +1069,14 @@ class AutoMLTablesListTableSpecsOperator(BaseOperator): result = [TableSpec.to_dict(spec) for spec in page_iterator] self.log.info(result) self.log.info("Table specs obtained.") + project_id = self.project_id or hook.project_id + if project_id: + AutoMLDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=self.dataset_id, + project_id=project_id, + ) return result @@ -1017,6 +1112,7 @@ class AutoMLListDatasetOperator(BaseOperator): "project_id", "impersonation_chain", ) + operator_extra_links = (AutoMLDatasetListLink(),) def __init__( self, @@ -1060,6 +1156,9 @@ class AutoMLListDatasetOperator(BaseOperator): key="dataset_id_list", value=[hook.extract_object_id(d) for d in result], ) + project_id = self.project_id or hook.project_id + if project_id: + AutoMLDatasetListLink.persist(context=context, task_instance=self, project_id=project_id) return result diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 43c1f53e44..6ee301dc48 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -1034,6 +1034,11 @@ extra-links: - airflow.providers.google.cloud.links.cloud_build.CloudBuildListLink - airflow.providers.google.cloud.links.cloud_build.CloudBuildTriggersListLink - airflow.providers.google.cloud.links.cloud_build.CloudBuildTriggerDetailsLink + - airflow.providers.google.cloud.links.automl.AutoMLDatasetLink + - airflow.providers.google.cloud.links.automl.AutoMLDatasetListLink + - airflow.providers.google.cloud.links.automl.AutoMLModelLink + - airflow.providers.google.cloud.links.automl.AutoMLModelTrainLink + - airflow.providers.google.cloud.links.automl.AutoMLModelPredictLink - airflow.providers.google.cloud.links.life_sciences.LifeSciencesLink - airflow.providers.google.cloud.links.cloud_functions.CloudFunctionsDetailsLink - airflow.providers.google.cloud.links.cloud_functions.CloudFunctionsListLink diff --git a/docs/apache-airflow-providers-google/operators/cloud/automl.rst b/docs/apache-airflow-providers-google/operators/cloud/automl.rst index 4e93a4f5e3..fd4fbbc2d2 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/automl.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/automl.rst @@ -41,7 +41,7 @@ To create a Google AutoML dataset you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator`. The operator returns dataset id in :ref:`XCom <concepts:xcom>` under ``dataset_id`` key. -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py :language: python :dedent: 4 :start-after: [START howto_operator_automl_create_dataset] @@ -50,7 +50,7 @@ The operator returns dataset id in :ref:`XCom <concepts:xcom>` under ``dataset_i After creating a dataset you can use it to import some data using :class:`~airflow.providers.google.cloud.operators.automl.AutoMLImportDataOperator`. -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py :language: python :dedent: 4 :start-after: [START howto_operator_automl_import_data] @@ -59,7 +59,7 @@ After creating a dataset you can use it to import some data using To update dataset you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesUpdateDatasetOperator`. -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py :language: python :dedent: 4 :start-after: [START howto_operator_automl_update_dataset] @@ -74,7 +74,7 @@ Listing Table And Columns Specs To list table specs you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesListTableSpecsOperator`. -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py :language: python :dedent: 4 :start-after: [START howto_operator_automl_specs] @@ -83,7 +83,7 @@ To list table specs you can use To list column specs you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesListColumnSpecsOperator`. -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py :language: python :dedent: 4 :start-after: [START howto_operator_automl_column_specs] @@ -102,7 +102,7 @@ To create a Google AutoML model you can use The operator will wait for the operation to complete. Additionally the operator returns the id of model in :ref:`XCom <concepts:xcom>` under ``model_id`` key. -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py :language: python :dedent: 4 :start-after: [START howto_operator_automl_create_model] @@ -111,7 +111,7 @@ returns the id of model in :ref:`XCom <concepts:xcom>` under ``model_id`` key. To get existing model one can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLGetModelOperator`. -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py :language: python :dedent: 4 :start-after: [START howto_operator_get_model] @@ -120,7 +120,7 @@ To get existing model one can use Once a model is created it could be deployed using :class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeployModelOperator`. -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py :language: python :dedent: 4 :start-after: [START howto_operator_deploy_model] @@ -129,7 +129,7 @@ Once a model is created it could be deployed using If you wish to delete a model you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteModelOperator`. -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py :language: python :dedent: 4 :start-after: [START howto_operator_automl_delete_model] @@ -146,13 +146,13 @@ To obtain predictions from Google Cloud AutoML model you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator`. In the first case the model must be deployed. -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py :language: python :dedent: 4 :start-after: [START howto_operator_prediction] :end-before: [END howto_operator_prediction] -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py :language: python :dedent: 4 :start-after: [START howto_operator_batch_prediction] @@ -168,7 +168,7 @@ You can get a list of AutoML models using :class:`~airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator`. The operator returns list of datasets ids in :ref:`XCom <concepts:xcom>` under ``dataset_id_list`` key. -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py :language: python :dedent: 4 :start-after: [START howto_operator_list_dataset] @@ -177,7 +177,7 @@ of datasets ids in :ref:`XCom <concepts:xcom>` under ``dataset_id_list`` key. To delete a model you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`. The delete operator allows also to pass list or coma separated string of datasets ids to be deleted. -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py +.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py :language: python :dedent: 4 :start-after: [START howto_operator_delete_dataset] diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py index b02a0e2509..a39e368698 100644 --- a/tests/providers/google/cloud/operators/test_automl.py +++ b/tests/providers/google/cloud/operators/test_automl.py @@ -68,9 +68,8 @@ extract_object_id = CloudAutoMLHook.extract_object_id class TestAutoMLTrainModelOperator(unittest.TestCase): - @mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLTrainModelOperator.xcom_push") @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") - def test_execute(self, mock_hook, mock_xcom): + def test_execute(self, mock_hook): mock_hook.return_value.create_model.return_value.result.return_value = Model(name=MODEL_PATH) mock_hook.return_value.extract_object_id = extract_object_id op = AutoMLTrainModelOperator( @@ -79,7 +78,7 @@ class TestAutoMLTrainModelOperator(unittest.TestCase): project_id=GCP_PROJECT_ID, task_id=TASK_ID, ) - op.execute(context=None) + op.execute(context=mock.MagicMock()) mock_hook.return_value.create_model.assert_called_once_with( model=MODEL, location=GCP_LOCATION, @@ -88,7 +87,6 @@ class TestAutoMLTrainModelOperator(unittest.TestCase): timeout=None, metadata=(), ) - mock_xcom.assert_called_once_with(None, key="model_id", value=MODEL_ID) class TestAutoMLBatchPredictOperator(unittest.TestCase): @@ -106,7 +104,7 @@ class TestAutoMLBatchPredictOperator(unittest.TestCase): task_id=TASK_ID, prediction_params={}, ) - op.execute(context=None) + op.execute(context=mock.MagicMock()) mock_hook.return_value.batch_predict.assert_called_once_with( input_config=INPUT_CONFIG, location=GCP_LOCATION, @@ -133,7 +131,7 @@ class TestAutoMLPredictOperator(unittest.TestCase): task_id=TASK_ID, operation_params={"TEST_KEY": "TEST_VALUE"}, ) - op.execute(context=None) + op.execute(context=mock.MagicMock()) mock_hook.return_value.predict.assert_called_once_with( location=GCP_LOCATION, metadata=(), @@ -147,9 +145,8 @@ class TestAutoMLPredictOperator(unittest.TestCase): class TestAutoMLCreateImportOperator(unittest.TestCase): - @mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator.xcom_push") @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") - def test_execute(self, mock_hook, mock_xcom): + def test_execute(self, mock_hook): mock_hook.return_value.create_dataset.return_value = Dataset(name=DATASET_PATH) mock_hook.return_value.extract_object_id = extract_object_id @@ -159,7 +156,7 @@ class TestAutoMLCreateImportOperator(unittest.TestCase): project_id=GCP_PROJECT_ID, task_id=TASK_ID, ) - op.execute(context=None) + op.execute(context=mock.MagicMock()) mock_hook.return_value.create_dataset.assert_called_once_with( dataset=DATASET, location=GCP_LOCATION, @@ -168,7 +165,6 @@ class TestAutoMLCreateImportOperator(unittest.TestCase): retry=DEFAULT, timeout=None, ) - mock_xcom.assert_called_once_with(None, key="dataset_id", value=DATASET_ID) class TestAutoMLListColumnsSpecsOperator(unittest.TestCase): @@ -188,7 +184,7 @@ class TestAutoMLListColumnsSpecsOperator(unittest.TestCase): page_size=page_size, task_id=TASK_ID, ) - op.execute(context=None) + op.execute(context=mock.MagicMock()) mock_hook.return_value.list_column_specs.assert_called_once_with( dataset_id=DATASET_ID, field_mask=MASK, @@ -217,7 +213,7 @@ class TestAutoMLUpdateDatasetOperator(unittest.TestCase): location=GCP_LOCATION, task_id=TASK_ID, ) - op.execute(context=None) + op.execute(context=mock.MagicMock()) mock_hook.return_value.update_dataset.assert_called_once_with( dataset=dataset, metadata=(), @@ -239,7 +235,7 @@ class TestAutoMLGetModelOperator(unittest.TestCase): project_id=GCP_PROJECT_ID, task_id=TASK_ID, ) - op.execute(context=None) + op.execute(context=mock.MagicMock()) mock_hook.return_value.get_model.assert_called_once_with( location=GCP_LOCATION, metadata=(), @@ -303,7 +299,7 @@ class TestAutoMLDatasetImportOperator(unittest.TestCase): input_config=INPUT_CONFIG, task_id=TASK_ID, ) - op.execute(context=None) + op.execute(context=mock.MagicMock()) mock_hook.return_value.import_data.assert_called_once_with( input_config=INPUT_CONFIG, location=GCP_LOCATION, @@ -329,7 +325,7 @@ class TestAutoMLTablesListTableSpecsOperator(unittest.TestCase): page_size=page_size, task_id=TASK_ID, ) - op.execute(context=None) + op.execute(context=mock.MagicMock()) mock_hook.return_value.list_table_specs.assert_called_once_with( dataset_id=DATASET_ID, filter_=filter_, @@ -343,11 +339,10 @@ class TestAutoMLTablesListTableSpecsOperator(unittest.TestCase): class TestAutoMLDatasetListOperator(unittest.TestCase): - @mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator.xcom_push") @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") - def test_execute(self, mock_hook, mock_xcom): + def test_execute(self, mock_hook): op = AutoMLListDatasetOperator(location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID) - op.execute(context=None) + op.execute(context=mock.MagicMock()) mock_hook.return_value.list_datasets.assert_called_once_with( location=GCP_LOCATION, metadata=(), @@ -355,7 +350,6 @@ class TestAutoMLDatasetListOperator(unittest.TestCase): retry=DEFAULT, timeout=None, ) - mock_xcom.assert_called_once_with(None, key="dataset_id_list", value=[]) class TestAutoMLDatasetDeleteOperator(unittest.TestCase): diff --git a/tests/providers/google/cloud/operators/test_automl_system.py b/tests/providers/google/cloud/operators/test_automl_system.py deleted file mode 100644 index dbda699fda..0000000000 --- a/tests/providers/google/cloud/operators/test_automl_system.py +++ /dev/null @@ -1,41 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pytest - -from tests.providers.google.cloud.utils.gcp_authenticator import GCP_AUTOML_KEY -from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context - - -@pytest.mark.backend("mysql", "postgres") -@pytest.mark.credential_file(GCP_AUTOML_KEY) -@pytest.mark.long_running -class TestAutoMLDatasetOperationsSystem(GoogleSystemTest): - @provide_gcp_context(GCP_AUTOML_KEY) - def test_run_example_dag(self): - self.run_dag("example_automl_dataset", CLOUD_DAG_FOLDER) - - -@pytest.mark.backend("mysql", "postgres") -@pytest.mark.credential_file(GCP_AUTOML_KEY) -@pytest.mark.long_running -class TestAutoMLModelOperationsSystem(GoogleSystemTest): - @provide_gcp_context(GCP_AUTOML_KEY) - def test_run_example_dag(self): - self.run_dag("example_create_and_deploy", CLOUD_DAG_FOLDER) diff --git a/tests/providers/google/cloud/utils/gcp_authenticator.py b/tests/providers/google/cloud/utils/gcp_authenticator.py index f2c9312074..7c95e57dc3 100644 --- a/tests/providers/google/cloud/utils/gcp_authenticator.py +++ b/tests/providers/google/cloud/utils/gcp_authenticator.py @@ -30,7 +30,6 @@ from tests.test_utils import AIRFLOW_MAIN_FOLDER from tests.test_utils.logging_command_executor import CommandExecutor GCP_AI_KEY = "gcp_ai.json" -GCP_AUTOML_KEY = "gcp_automl.json" GCP_BIGQUERY_KEY = "gcp_bigquery.json" GCP_BIGTABLE_KEY = "gcp_bigtable.json" GCP_CLOUD_BUILD_KEY = "gcp_cloud_build.json" diff --git a/tests/system/providers/google/cloud/automl/example_automl_dataset.py b/tests/system/providers/google/cloud/automl/example_automl_dataset.py new file mode 100644 index 0000000000..909da1593a --- /dev/null +++ b/tests/system/providers/google/cloud/automl/example_automl_dataset.py @@ -0,0 +1,201 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google AutoML service testing dataset operations. +""" +from __future__ import annotations + +import os +from copy import deepcopy +from datetime import datetime + +from airflow import models +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.operators.automl import ( + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLImportDataOperator, + AutoMLListDatasetOperator, + AutoMLTablesListColumnSpecsOperator, + AutoMLTablesListTableSpecsOperator, + AutoMLTablesUpdateDatasetOperator, +) +from airflow.providers.google.cloud.operators.gcs import ( + GCSCreateBucketOperator, + GCSDeleteBucketOperator, + GCSSynchronizeBucketsOperator, +) +from airflow.utils.trigger_rule import TriggerRule + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "automl_dataset" +GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") + +GCP_AUTOML_LOCATION = "us-central1" + +DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" +RESOURCE_DATA_BUCKET = "system-tests-resources" + +DATASET_NAME = f"ds_{DAG_ID}_{ENV_ID}" +DATASET = { + "display_name": DATASET_NAME, + "tables_dataset_metadata": {"target_column_spec_id": ""}, +} +AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/bank-marketing.csv" +IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}} + +extract_object_id = CloudAutoMLHook.extract_object_id + + +def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str: + """ + Using column name returns spec of the column. + """ + for column in columns_specs: + if column["display_name"] == column_name: + return extract_object_id(column) + raise Exception(f"Unknown target column: {column_name}") + + +with models.DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "automl"], + user_defined_macros={ + "get_target_column_spec": get_target_column_spec, + "target": "Class", + "extract_object_id": extract_object_id, + }, +) as dag: + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + storage_class="REGIONAL", + location=GCP_AUTOML_LOCATION, + ) + + move_dataset_file = GCSSynchronizeBucketsOperator( + task_id="move_dataset_to_bucket", + source_bucket=RESOURCE_DATA_BUCKET, + source_object="automl", + destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME, + destination_object="automl", + recursive=True, + ) + + # [START howto_operator_automl_create_dataset] + create_dataset = AutoMLCreateDatasetOperator( + task_id="create_dataset", + dataset=DATASET, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + dataset_id = create_dataset.output["dataset_id"] + # [END howto_operator_automl_create_dataset] + + # [START howto_operator_automl_import_data] + import_dataset_task = AutoMLImportDataOperator( + task_id="import_dataset_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + input_config=IMPORT_INPUT_CONFIG, + ) + # [END howto_operator_automl_import_data] + + # [START howto_operator_automl_specs] + list_tables_spec_task = AutoMLTablesListTableSpecsOperator( + task_id="list_tables_spec_task", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_automl_specs] + + # [START howto_operator_automl_column_specs] + list_columns_spec_task = AutoMLTablesListColumnSpecsOperator( + task_id="list_columns_spec_task", + dataset_id=dataset_id, + table_spec_id="{{ extract_object_id(task_instance.xcom_pull('list_tables_spec_task')[0]) }}", + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_automl_column_specs] + + # [START howto_operator_automl_update_dataset] + update = deepcopy(DATASET) + update["name"] = '{{ task_instance.xcom_pull("create_dataset")["name"] }}' + update["tables_dataset_metadata"][ # type: ignore + "target_column_spec_id" + ] = "{{ get_target_column_spec(task_instance.xcom_pull('list_columns_spec_task'), target) }}" + + update_dataset_task = AutoMLTablesUpdateDatasetOperator( + task_id="update_dataset_task", + dataset=update, + location=GCP_AUTOML_LOCATION, + ) + # [END howto_operator_automl_update_dataset] + + # [START howto_operator_list_dataset] + list_datasets_task = AutoMLListDatasetOperator( + task_id="list_datasets_task", + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_list_dataset] + + # [START howto_operator_delete_dataset] + delete_dataset = AutoMLDeleteDatasetOperator( + task_id="delete_dataset", + dataset_id="{{ task_instance.xcom_pull('list_datasets_task', key='dataset_id_list') | list }}", + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_delete_dataset] + + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE + ) + + ( + # TEST SETUP + [create_bucket >> move_dataset_file, create_dataset] + # TEST BODY + >> import_dataset_task + >> list_tables_spec_task + >> list_columns_spec_task + >> update_dataset_task + >> list_datasets_task + # TEST TEARDOWN + >> delete_dataset + >> delete_bucket + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/system/providers/google/cloud/automl/example_automl_model.py b/tests/system/providers/google/cloud/automl/example_automl_model.py new file mode 100644 index 0000000000..f08b46b024 --- /dev/null +++ b/tests/system/providers/google/cloud/automl/example_automl_model.py @@ -0,0 +1,285 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google AutoML service testing model operations. +""" +from __future__ import annotations + +import os +from copy import deepcopy +from datetime import datetime + +from google.protobuf.struct_pb2 import Value + +from airflow import models +from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook +from airflow.providers.google.cloud.operators.automl import ( + AutoMLBatchPredictOperator, + AutoMLCreateDatasetOperator, + AutoMLDeleteDatasetOperator, + AutoMLDeleteModelOperator, + AutoMLDeployModelOperator, + AutoMLGetModelOperator, + AutoMLImportDataOperator, + AutoMLPredictOperator, + AutoMLTablesListColumnSpecsOperator, + AutoMLTablesListTableSpecsOperator, + AutoMLTablesUpdateDatasetOperator, + AutoMLTrainModelOperator, +) +from airflow.providers.google.cloud.operators.gcs import ( + GCSCreateBucketOperator, + GCSDeleteBucketOperator, + GCSSynchronizeBucketsOperator, +) +from airflow.utils.trigger_rule import TriggerRule + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "automl_model" +GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") + +GCP_AUTOML_LOCATION = "us-central1" + +DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" +RESOURCE_DATA_BUCKET = "system-tests-resources" + +DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}" +DATASET = { + "display_name": DATASET_NAME, + "tables_dataset_metadata": {"target_column_spec_id": ""}, +} +AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/bank-marketing.csv" +IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}} +IMPORT_OUTPUT_CONFIG = { + "gcs_destination": {"output_uri_prefix": f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl"} +} + +MODEL_NAME = f"model_{DAG_ID}_{ENV_ID}" +MODEL = { + "display_name": MODEL_NAME, + "tables_model_metadata": {"train_budget_milli_node_hours": 1000}, +} + +PREDICT_VALUES = [ + Value(string_value="51"), + Value(string_value="blue-collar"), + Value(string_value="married"), + Value(string_value="primary"), + Value(string_value="no"), + Value(string_value="620"), + Value(string_value="yes"), + Value(string_value="yes"), + Value(string_value="cellular"), + Value(string_value="29"), + Value(string_value="jul"), + Value(string_value="88"), + Value(string_value="10"), + Value(string_value="-1"), + Value(string_value="0"), + Value(string_value="unknown"), +] + +extract_object_id = CloudAutoMLHook.extract_object_id + + +def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str: + """ + Using column name returns spec of the column. + """ + for column in columns_specs: + if column["display_name"] == column_name: + return extract_object_id(column) + raise Exception(f"Unknown target column: {column_name}") + + +with models.DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + user_defined_macros={ + "get_target_column_spec": get_target_column_spec, + "target": "Class", + "extract_object_id": extract_object_id, + }, + tags=["example", "automl"], +) as dag: + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + storage_class="REGIONAL", + location=GCP_AUTOML_LOCATION, + ) + + move_dataset_file = GCSSynchronizeBucketsOperator( + task_id="move_data_to_bucket", + source_bucket=RESOURCE_DATA_BUCKET, + source_object="automl", + destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME, + destination_object="automl", + recursive=True, + ) + + create_dataset = AutoMLCreateDatasetOperator( + task_id="create_dataset", + dataset=DATASET, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + dataset_id = create_dataset.output["dataset_id"] + MODEL["dataset_id"] = dataset_id + import_dataset = AutoMLImportDataOperator( + task_id="import_dataset", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + input_config=IMPORT_INPUT_CONFIG, + ) + + list_tables_spec = AutoMLTablesListTableSpecsOperator( + task_id="list_tables_spec", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + list_columns_spec = AutoMLTablesListColumnSpecsOperator( + task_id="list_columns_spec", + dataset_id=dataset_id, + table_spec_id="{{ extract_object_id(task_instance.xcom_pull('list_tables_spec')[0]) }}", + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + + update = deepcopy(DATASET) + update["name"] = '{{ task_instance.xcom_pull("create_dataset")["name"] }}' + update["tables_dataset_metadata"][ # type: ignore + "target_column_spec_id" + ] = "{{ get_target_column_spec(task_instance.xcom_pull('list_columns_spec'), target) }}" + + update_dataset = AutoMLTablesUpdateDatasetOperator( + task_id="update_dataset", + dataset=update, + location=GCP_AUTOML_LOCATION, + ) + + # [START howto_operator_automl_create_model] + create_model = AutoMLTrainModelOperator( + task_id="create_model", + model=MODEL, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + model_id = create_model.output["model_id"] + # [END howto_operator_automl_create_model] + + # [START howto_operator_get_model] + get_model = AutoMLGetModelOperator( + task_id="get_model", + model_id=model_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_get_model] + + # [START howto_operator_deploy_model] + deploy_model = AutoMLDeployModelOperator( + task_id="deploy_model", + model_id=model_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_deploy_model] + + # [START howto_operator_prediction] + predict_task = AutoMLPredictOperator( + task_id="predict_task", + model_id=model_id, + payload={ + "row": { + "values": PREDICT_VALUES, + } + }, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_prediction] + + # [START howto_operator_batch_prediction] + batch_predict_task = AutoMLBatchPredictOperator( + task_id="batch_predict_task", + model_id=model_id, + input_config=IMPORT_INPUT_CONFIG, + output_config=IMPORT_OUTPUT_CONFIG, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_batch_prediction] + + # [START howto_operator_automl_delete_model] + delete_model = AutoMLDeleteModelOperator( + task_id="delete_model", + model_id=model_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_automl_delete_model] + + delete_dataset = AutoMLDeleteDatasetOperator( + task_id="delete_dataset", + dataset_id=dataset_id, + location=GCP_AUTOML_LOCATION, + project_id=GCP_PROJECT_ID, + trigger_rule=TriggerRule.ALL_DONE, + ) + + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE + ) + + ( + # TEST SETUP + [create_bucket >> move_dataset_file, create_dataset] + >> import_dataset + >> list_tables_spec + >> list_columns_spec + >> update_dataset + # TEST BODY + >> create_model + >> get_model + >> deploy_model + >> predict_task + >> batch_predict_task + # TEST TEARDOWN + >> delete_model + >> delete_dataset + >> delete_bucket + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/system/providers/google/cloud/automl/resources/__init__.py b/tests/system/providers/google/cloud/automl/resources/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/system/providers/google/cloud/automl/resources/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.