This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 90e6277da6 Auto ML assets (#25466)
90e6277da6 is described below
commit 90e6277da6b4102cf565134739af10bafa9d3894
Author: Maksim <[email protected]>
AuthorDate: Mon Jan 23 18:19:28 2023 +0300
Auto ML assets (#25466)
---
.../cloud/example_dags/example_automl_tables.py | 319 ---------------------
airflow/providers/google/cloud/links/automl.py | 163 +++++++++++
airflow/providers/google/cloud/operators/automl.py | 103 ++++++-
airflow/providers/google/provider.yaml | 5 +
.../operators/cloud/automl.rst | 26 +-
.../google/cloud/operators/test_automl.py | 32 +--
.../google/cloud/operators/test_automl_system.py | 41 ---
.../google/cloud/utils/gcp_authenticator.py | 1 -
.../google/cloud/automl/example_automl_dataset.py | 201 +++++++++++++
.../google/cloud/automl/example_automl_model.py | 285 ++++++++++++++++++
.../google/cloud/automl/resources/__init__.py | 16 ++
11 files changed, 797 insertions(+), 395 deletions(-)
diff --git a/airflow/providers/google/cloud/example_dags/example_automl_tables.py b/airflow/providers/google/cloud/example_dags/example_automl_tables.py
deleted file mode 100644
index 89006402f7..0000000000
--- a/airflow/providers/google/cloud/example_dags/example_automl_tables.py
+++ /dev/null
@@ -1,319 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Example Airflow DAG that uses Google AutoML services.
-"""
-from __future__ import annotations
-
-import os
-from copy import deepcopy
-from datetime import datetime
-from typing import cast
-
-from airflow import models
-from airflow.models.xcom_arg import XComArg
-from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
-from airflow.providers.google.cloud.operators.automl import (
- AutoMLBatchPredictOperator,
- AutoMLCreateDatasetOperator,
- AutoMLDeleteDatasetOperator,
- AutoMLDeleteModelOperator,
- AutoMLDeployModelOperator,
- AutoMLGetModelOperator,
- AutoMLImportDataOperator,
- AutoMLListDatasetOperator,
- AutoMLPredictOperator,
- AutoMLTablesListColumnSpecsOperator,
- AutoMLTablesListTableSpecsOperator,
- AutoMLTablesUpdateDatasetOperator,
- AutoMLTrainModelOperator,
-)
-
-START_DATE = datetime(2021, 1, 1)
-
-GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "your-project-id")
-GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1")
-GCP_AUTOML_DATASET_BUCKET = os.environ.get(
- "GCP_AUTOML_DATASET_BUCKET", "gs://INVALID BUCKET NAME/bank-marketing.csv"
-)
-TARGET = os.environ.get("GCP_AUTOML_TARGET", "Deposit")
-
-# Example values
-MODEL_ID = "TBL123456"
-DATASET_ID = "TBL123456"
-
-# Example model
-MODEL = {
- "display_name": "auto_model_1",
- "dataset_id": DATASET_ID,
- "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
-}
-
-# Example dataset
-DATASET = {
- "display_name": "test_set",
- "tables_dataset_metadata": {"target_column_spec_id": ""},
-}
-
-IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [GCP_AUTOML_DATASET_BUCKET]}}
-
-extract_object_id = CloudAutoMLHook.extract_object_id
-
-
-def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str:
- """
- Using column name returns spec of the column.
- """
- for column in columns_specs:
- if column["display_name"] == column_name:
- return extract_object_id(column)
- raise Exception(f"Unknown target column: {column_name}")
-
-
-# Example DAG to create dataset, train model_id and deploy it.
-with models.DAG(
- "example_create_and_deploy",
- start_date=START_DATE,
- catchup=False,
- user_defined_macros={
- "get_target_column_spec": get_target_column_spec,
- "target": TARGET,
- "extract_object_id": extract_object_id,
- },
- tags=["example"],
-) as create_deploy_dag:
- # [START howto_operator_automl_create_dataset]
- create_dataset_task = AutoMLCreateDatasetOperator(
- task_id="create_dataset_task",
- dataset=DATASET,
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
-
- dataset_id = cast(str, XComArg(create_dataset_task, key="dataset_id"))
- # [END howto_operator_automl_create_dataset]
-
- MODEL["dataset_id"] = dataset_id
-
- # [START howto_operator_automl_import_data]
- import_dataset_task = AutoMLImportDataOperator(
- task_id="import_dataset_task",
- dataset_id=dataset_id,
- location=GCP_AUTOML_LOCATION,
- input_config=IMPORT_INPUT_CONFIG,
- )
- # [END howto_operator_automl_import_data]
-
- # [START howto_operator_automl_specs]
- list_tables_spec_task = AutoMLTablesListTableSpecsOperator(
- task_id="list_tables_spec_task",
- dataset_id=dataset_id,
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
- # [END howto_operator_automl_specs]
-
- # [START howto_operator_automl_column_specs]
- list_columns_spec_task = AutoMLTablesListColumnSpecsOperator(
- task_id="list_columns_spec_task",
- dataset_id=dataset_id,
- table_spec_id="{{
extract_object_id(task_instance.xcom_pull('list_tables_spec_task')[0]) }}",
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
- # [END howto_operator_automl_column_specs]
-
- # [START howto_operator_automl_update_dataset]
- update = deepcopy(DATASET)
- update["name"] = '{{
task_instance.xcom_pull("create_dataset_task")["name"] }}'
- update["tables_dataset_metadata"][ # type: ignore
- "target_column_spec_id"
- ] = "{{
get_target_column_spec(task_instance.xcom_pull('list_columns_spec_task'),
target) }}"
-
- update_dataset_task = AutoMLTablesUpdateDatasetOperator(
- task_id="update_dataset_task",
- dataset=update,
- location=GCP_AUTOML_LOCATION,
- )
- # [END howto_operator_automl_update_dataset]
-
- # [START howto_operator_automl_create_model]
- create_model_task = AutoMLTrainModelOperator(
- task_id="create_model_task",
- model=MODEL,
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
-
- model_id = cast(str, XComArg(create_model_task, key="model_id"))
- # [END howto_operator_automl_create_model]
-
- # [START howto_operator_automl_delete_model]
- delete_model_task = AutoMLDeleteModelOperator(
- task_id="delete_model_task",
- model_id=model_id,
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
- # [END howto_operator_automl_delete_model]
-
- delete_datasets_task = AutoMLDeleteDatasetOperator(
- task_id="delete_datasets_task",
- dataset_id=dataset_id,
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
-
- (
- import_dataset_task
- >> list_tables_spec_task
- >> list_columns_spec_task
- >> update_dataset_task
- >> create_model_task
- )
- delete_model_task >> delete_datasets_task
-
- # Task dependencies created via `XComArgs`:
- # create_dataset_task >> import_dataset_task
- # create_dataset_task >> list_tables_spec_task
- # create_dataset_task >> list_columns_spec_task
- # create_dataset_task >> create_model_task
- # create_model_task >> delete_model_task
- # create_dataset_task >> delete_datasets_task
-
-
-# Example DAG for AutoML datasets operations
-with models.DAG(
- "example_automl_dataset",
- start_date=START_DATE,
- catchup=False,
- user_defined_macros={"extract_object_id": extract_object_id},
-) as example_dag:
- create_dataset_task2 = AutoMLCreateDatasetOperator(
- task_id="create_dataset_task",
- dataset=DATASET,
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
-
- dataset_id = cast(str, XComArg(create_dataset_task2, key="dataset_id"))
-
- import_dataset_task = AutoMLImportDataOperator(
- task_id="import_dataset_task",
- dataset_id=dataset_id,
- location=GCP_AUTOML_LOCATION,
- input_config=IMPORT_INPUT_CONFIG,
- )
-
- list_tables_spec_task = AutoMLTablesListTableSpecsOperator(
- task_id="list_tables_spec_task",
- dataset_id=dataset_id,
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
-
- list_columns_spec_task = AutoMLTablesListColumnSpecsOperator(
- task_id="list_columns_spec_task",
- dataset_id=dataset_id,
- table_spec_id="{{
extract_object_id(task_instance.xcom_pull('list_tables_spec_task')[0]) }}",
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
-
- # [START howto_operator_list_dataset]
- list_datasets_task = AutoMLListDatasetOperator(
- task_id="list_datasets_task",
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
- # [END howto_operator_list_dataset]
-
- # [START howto_operator_delete_dataset]
- delete_datasets_task = AutoMLDeleteDatasetOperator(
- task_id="delete_datasets_task",
- dataset_id="{{ task_instance.xcom_pull('list_datasets_task',
key='dataset_id_list') | list }}",
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
- # [END howto_operator_delete_dataset]
-
- (
- import_dataset_task
- >> list_tables_spec_task
- >> list_columns_spec_task
- >> list_datasets_task
- >> delete_datasets_task
- )
-
- # Task dependencies created via `XComArgs`:
- # create_dataset_task >> import_dataset_task
- # create_dataset_task >> list_tables_spec_task
- # create_dataset_task >> list_columns_spec_task
-
-
-with models.DAG(
- "example_gcp_get_deploy",
- start_date=START_DATE,
- catchup=False,
- tags=["example"],
-) as get_deploy_dag:
- # [START howto_operator_get_model]
- get_model_task = AutoMLGetModelOperator(
- task_id="get_model_task",
- model_id=MODEL_ID,
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
- # [END howto_operator_get_model]
-
- # [START howto_operator_deploy_model]
- deploy_model_task = AutoMLDeployModelOperator(
- task_id="deploy_model_task",
- model_id=MODEL_ID,
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
- # [END howto_operator_deploy_model]
-
-
-with models.DAG(
- "example_gcp_predict",
- start_date=START_DATE,
- catchup=False,
- tags=["example"],
-) as predict_dag:
- # [START howto_operator_prediction]
- predict_task = AutoMLPredictOperator(
- task_id="predict_task",
- model_id=MODEL_ID,
- payload={}, # Add your own payload, the used model_id must be deployed
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
- # [END howto_operator_prediction]
-
- # [START howto_operator_batch_prediction]
- batch_predict_task = AutoMLBatchPredictOperator(
- task_id="batch_predict_task",
- model_id=MODEL_ID,
- input_config={}, # Add your config
- output_config={}, # Add your config
- location=GCP_AUTOML_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
- # [END howto_operator_batch_prediction]
diff --git a/airflow/providers/google/cloud/links/automl.py b/airflow/providers/google/cloud/links/automl.py
new file mode 100644
index 0000000000..f2deee6642
--- /dev/null
+++ b/airflow/providers/google/cloud/links/automl.py
@@ -0,0 +1,163 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains Google AutoML links."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.google.cloud.links.base import BaseGoogleLink
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+AUTOML_BASE_LINK = "https://console.cloud.google.com/automl-tables"
+AUTOML_DATASET_LINK = (
+ AUTOML_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/schemav2?project={project_id}"
+)
+AUTOML_DATASET_LIST_LINK = AUTOML_BASE_LINK + "/datasets?project={project_id}"
+AUTOML_MODEL_LINK = (
+ AUTOML_BASE_LINK
+ + "/locations/{location}/datasets/{dataset_id};modelId={model_id}/evaluate?project={project_id}"
+)
+AUTOML_MODEL_TRAIN_LINK = (
+ AUTOML_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/train?project={project_id}"
+)
+AUTOML_MODEL_PREDICT_LINK = (
+ AUTOML_BASE_LINK
+ + "/locations/{location}/datasets/{dataset_id};modelId={model_id}/predict?project={project_id}"
+)
+
+
+class AutoMLDatasetLink(BaseGoogleLink):
+ """Helper class for constructing AutoML Dataset link"""
+
+ name = "AutoML Dataset"
+ key = "automl_dataset"
+ format_str = AUTOML_DATASET_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ task_instance,
+ dataset_id: str,
+ project_id: str,
+ ):
+ task_instance.xcom_push(
+ context,
+ key=AutoMLDatasetLink.key,
+ value={"location": task_instance.location, "dataset_id":
dataset_id, "project_id": project_id},
+ )
+
+
+class AutoMLDatasetListLink(BaseGoogleLink):
+ """Helper class for constructing AutoML Dataset List link"""
+
+ name = "AutoML Dataset List"
+ key = "automl_dataset_list"
+ format_str = AUTOML_DATASET_LIST_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ task_instance,
+ project_id: str,
+ ):
+ task_instance.xcom_push(
+ context,
+ key=AutoMLDatasetListLink.key,
+ value={
+ "project_id": project_id,
+ },
+ )
+
+
+class AutoMLModelLink(BaseGoogleLink):
+ """Helper class for constructing AutoML Model link"""
+
+ name = "AutoML Model"
+ key = "automl_model"
+ format_str = AUTOML_MODEL_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ task_instance,
+ dataset_id: str,
+ model_id: str,
+ project_id: str,
+ ):
+ task_instance.xcom_push(
+ context,
+ key=AutoMLModelLink.key,
+ value={
+ "location": task_instance.location,
+ "dataset_id": dataset_id,
+ "model_id": model_id,
+ "project_id": project_id,
+ },
+ )
+
+
+class AutoMLModelTrainLink(BaseGoogleLink):
+ """Helper class for constructing AutoML Model Train link"""
+
+ name = "AutoML Model Train"
+ key = "automl_model_train"
+ format_str = AUTOML_MODEL_TRAIN_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ task_instance,
+ project_id: str,
+ ):
+ task_instance.xcom_push(
+ context,
+ key=AutoMLModelTrainLink.key,
+ value={
+ "location": task_instance.location,
+ "dataset_id": task_instance.model["dataset_id"],
+ "project_id": project_id,
+ },
+ )
+
+
+class AutoMLModelPredictLink(BaseGoogleLink):
+ """Helper class for constructing AutoML Model Predict link"""
+
+ name = "AutoML Model Predict"
+ key = "automl_model_predict"
+ format_str = AUTOML_MODEL_PREDICT_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ task_instance,
+ model_id: str,
+ project_id: str,
+ ):
+ task_instance.xcom_push(
+ context,
+ key=AutoMLModelPredictLink.key,
+ value={
+ "location": task_instance.location,
+ "dataset_id": "-",
+ "model_id": model_id,
+ "project_id": project_id,
+ },
+ )
diff --git a/airflow/providers/google/cloud/operators/automl.py b/airflow/providers/google/cloud/operators/automl.py
index ff1ee00c3e..b1da9d3b77 100644
--- a/airflow/providers/google/cloud/operators/automl.py
+++ b/airflow/providers/google/cloud/operators/automl.py
@@ -34,6 +34,13 @@ from google.cloud.automl_v1beta1 import (
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.links.automl import (
+ AutoMLDatasetLink,
+ AutoMLDatasetListLink,
+ AutoMLModelLink,
+ AutoMLModelPredictLink,
+ AutoMLModelTrainLink,
+)
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -75,6 +82,10 @@ class AutoMLTrainModelOperator(BaseOperator):
"project_id",
"impersonation_chain",
)
+ operator_extra_links = (
+ AutoMLModelTrainLink(),
+ AutoMLModelLink(),
+ )
def __init__(
self,
@@ -114,11 +125,22 @@ class AutoMLTrainModelOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLModelTrainLink.persist(context=context, task_instance=self, project_id=project_id)
result = Model.to_dict(operation.result())
model_id = hook.extract_object_id(result)
self.log.info("Model created: %s", model_id)
self.xcom_push(context, key="model_id", value=model_id)
+ if project_id:
+ AutoMLModelLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=self.model["dataset_id"] or "-",
+ model_id=model_id,
+ project_id=project_id,
+ )
return result
@@ -158,6 +180,7 @@ class AutoMLPredictOperator(BaseOperator):
"project_id",
"impersonation_chain",
)
+ operator_extra_links = (AutoMLModelPredictLink(),)
def __init__(
self,
@@ -202,6 +225,14 @@ class AutoMLPredictOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLModelPredictLink.persist(
+ context=context,
+ task_instance=self,
+ model_id=self.model_id,
+ project_id=project_id,
+ )
return PredictResponse.to_dict(result)
@@ -252,6 +283,7 @@ class AutoMLBatchPredictOperator(BaseOperator):
"project_id",
"impersonation_chain",
)
+ operator_extra_links = (AutoMLModelPredictLink(),)
def __init__(
self,
@@ -302,6 +334,14 @@ class AutoMLBatchPredictOperator(BaseOperator):
)
result = BatchPredictResult.to_dict(operation.result())
self.log.info("Batch prediction ready.")
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLModelPredictLink.persist(
+ context=context,
+ task_instance=self,
+ model_id=self.model_id,
+ project_id=project_id,
+ )
return result
@@ -341,6 +381,7 @@ class AutoMLCreateDatasetOperator(BaseOperator):
"project_id",
"impersonation_chain",
)
+ operator_extra_links = (AutoMLDatasetLink(),)
def __init__(
self,
@@ -385,6 +426,14 @@ class AutoMLCreateDatasetOperator(BaseOperator):
self.log.info("Creating completed. Dataset id: %s", dataset_id)
self.xcom_push(context, key="dataset_id", value=dataset_id)
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLDatasetLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=dataset_id,
+ project_id=project_id,
+ )
return result
@@ -426,6 +475,7 @@ class AutoMLImportDataOperator(BaseOperator):
"project_id",
"impersonation_chain",
)
+ operator_extra_links = (AutoMLDatasetLink(),)
def __init__(
self,
@@ -470,6 +520,14 @@ class AutoMLImportDataOperator(BaseOperator):
)
operation.result()
self.log.info("Import completed")
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLDatasetLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=self.dataset_id,
+ project_id=project_id,
+ )
class AutoMLTablesListColumnSpecsOperator(BaseOperator):
@@ -518,6 +576,7 @@ class AutoMLTablesListColumnSpecsOperator(BaseOperator):
"project_id",
"impersonation_chain",
)
+ operator_extra_links = (AutoMLDatasetLink(),)
def __init__(
self,
@@ -570,7 +629,14 @@ class AutoMLTablesListColumnSpecsOperator(BaseOperator):
)
result = [ColumnSpec.to_dict(spec) for spec in page_iterator]
self.log.info("Columns specs obtained.")
-
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLDatasetLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=self.dataset_id,
+ project_id=project_id,
+ )
return result
@@ -610,6 +676,7 @@ class AutoMLTablesUpdateDatasetOperator(BaseOperator):
"location",
"impersonation_chain",
)
+ operator_extra_links = (AutoMLDatasetLink(),)
def __init__(
self,
@@ -649,6 +716,14 @@ class AutoMLTablesUpdateDatasetOperator(BaseOperator):
metadata=self.metadata,
)
self.log.info("Dataset updated.")
+ project_id = hook.project_id
+ if project_id:
+ AutoMLDatasetLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=hook.extract_object_id(self.dataset),
+ project_id=project_id,
+ )
return Dataset.to_dict(result)
@@ -687,6 +762,7 @@ class AutoMLGetModelOperator(BaseOperator):
"project_id",
"impersonation_chain",
)
+ operator_extra_links = (AutoMLModelLink(),)
def __init__(
self,
@@ -725,7 +801,17 @@ class AutoMLGetModelOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- return Model.to_dict(result)
+ model = Model.to_dict(result)
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLModelLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=model["dataset_id"],
+ model_id=self.model_id,
+ project_id=project_id,
+ )
+ return model
class AutoMLDeleteModelOperator(BaseOperator):
@@ -935,6 +1021,7 @@ class AutoMLTablesListTableSpecsOperator(BaseOperator):
"project_id",
"impersonation_chain",
)
+ operator_extra_links = (AutoMLDatasetLink(),)
def __init__(
self,
@@ -982,6 +1069,14 @@ class AutoMLTablesListTableSpecsOperator(BaseOperator):
result = [TableSpec.to_dict(spec) for spec in page_iterator]
self.log.info(result)
self.log.info("Table specs obtained.")
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLDatasetLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=self.dataset_id,
+ project_id=project_id,
+ )
return result
@@ -1017,6 +1112,7 @@ class AutoMLListDatasetOperator(BaseOperator):
"project_id",
"impersonation_chain",
)
+ operator_extra_links = (AutoMLDatasetListLink(),)
def __init__(
self,
@@ -1060,6 +1156,9 @@ class AutoMLListDatasetOperator(BaseOperator):
key="dataset_id_list",
value=[hook.extract_object_id(d) for d in result],
)
+ project_id = self.project_id or hook.project_id
+ if project_id:
+ AutoMLDatasetListLink.persist(context=context, task_instance=self, project_id=project_id)
return result
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 43c1f53e44..6ee301dc48 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -1034,6 +1034,11 @@ extra-links:
- airflow.providers.google.cloud.links.cloud_build.CloudBuildListLink
- airflow.providers.google.cloud.links.cloud_build.CloudBuildTriggersListLink
- airflow.providers.google.cloud.links.cloud_build.CloudBuildTriggerDetailsLink
+ - airflow.providers.google.cloud.links.automl.AutoMLDatasetLink
+ - airflow.providers.google.cloud.links.automl.AutoMLDatasetListLink
+ - airflow.providers.google.cloud.links.automl.AutoMLModelLink
+ - airflow.providers.google.cloud.links.automl.AutoMLModelTrainLink
+ - airflow.providers.google.cloud.links.automl.AutoMLModelPredictLink
- airflow.providers.google.cloud.links.life_sciences.LifeSciencesLink
- airflow.providers.google.cloud.links.cloud_functions.CloudFunctionsDetailsLink
- airflow.providers.google.cloud.links.cloud_functions.CloudFunctionsListLink
diff --git a/docs/apache-airflow-providers-google/operators/cloud/automl.rst b/docs/apache-airflow-providers-google/operators/cloud/automl.rst
index 4e93a4f5e3..fd4fbbc2d2 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/automl.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/automl.rst
@@ -41,7 +41,7 @@ To create a Google AutoML dataset you can use
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator`.
The operator returns dataset id in :ref:`XCom <concepts:xcom>` under ``dataset_id`` key.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
:language: python
:dedent: 4
:start-after: [START howto_operator_automl_create_dataset]
@@ -50,7 +50,7 @@ The operator returns dataset id in :ref:`XCom <concepts:xcom>` under ``dataset_i
After creating a dataset you can use it to import some data using
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLImportDataOperator`.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
:language: python
:dedent: 4
:start-after: [START howto_operator_automl_import_data]
@@ -59,7 +59,7 @@ After creating a dataset you can use it to import some data using
To update dataset you can use
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesUpdateDatasetOperator`.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
:language: python
:dedent: 4
:start-after: [START howto_operator_automl_update_dataset]
@@ -74,7 +74,7 @@ Listing Table And Columns Specs
To list table specs you can use
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesListTableSpecsOperator`.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
:language: python
:dedent: 4
:start-after: [START howto_operator_automl_specs]
@@ -83,7 +83,7 @@ To list table specs you can use
To list column specs you can use
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesListColumnSpecsOperator`.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
:language: python
:dedent: 4
:start-after: [START howto_operator_automl_column_specs]
@@ -102,7 +102,7 @@ To create a Google AutoML model you can use
The operator will wait for the operation to complete. Additionally the operator
returns the id of model in :ref:`XCom <concepts:xcom>` under ``model_id`` key.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
:language: python
:dedent: 4
:start-after: [START howto_operator_automl_create_model]
@@ -111,7 +111,7 @@ returns the id of model in :ref:`XCom <concepts:xcom>` under ``model_id`` key.
To get existing model one can use
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLGetModelOperator`.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
:language: python
:dedent: 4
:start-after: [START howto_operator_get_model]
@@ -120,7 +120,7 @@ To get existing model one can use
Once a model is created it could be deployed using
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeployModelOperator`.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
:language: python
:dedent: 4
:start-after: [START howto_operator_deploy_model]
@@ -129,7 +129,7 @@ Once a model is created it could be deployed using
If you wish to delete a model you can use
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteModelOperator`.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
:language: python
:dedent: 4
:start-after: [START howto_operator_automl_delete_model]
@@ -146,13 +146,13 @@ To obtain predictions from Google Cloud AutoML model you can use
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator`. In the first case
the model must be deployed.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
:language: python
:dedent: 4
:start-after: [START howto_operator_prediction]
:end-before: [END howto_operator_prediction]
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
:language: python
:dedent: 4
:start-after: [START howto_operator_batch_prediction]
@@ -168,7 +168,7 @@ You can get a list of AutoML models using
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator`. The operator returns list
of datasets ids in :ref:`XCom <concepts:xcom>` under ``dataset_id_list`` key.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
:language: python
:dedent: 4
:start-after: [START howto_operator_list_dataset]
@@ -177,7 +177,7 @@ of datasets ids in :ref:`XCom <concepts:xcom>` under ``dataset_id_list`` key.
To delete a model you can use
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`.
The delete operator allows also to pass list or coma separated string of
datasets ids to be deleted.
-.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_automl_tables.py
+.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
:language: python
:dedent: 4
:start-after: [START howto_operator_delete_dataset]
diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py
index b02a0e2509..a39e368698 100644
--- a/tests/providers/google/cloud/operators/test_automl.py
+++ b/tests/providers/google/cloud/operators/test_automl.py
@@ -68,9 +68,8 @@ extract_object_id = CloudAutoMLHook.extract_object_id
class TestAutoMLTrainModelOperator(unittest.TestCase):
- @mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLTrainModelOperator.xcom_push")
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook, mock_xcom):
+ def test_execute(self, mock_hook):
mock_hook.return_value.create_model.return_value.result.return_value = Model(name=MODEL_PATH)
mock_hook.return_value.extract_object_id = extract_object_id
op = AutoMLTrainModelOperator(
@@ -79,7 +78,7 @@ class TestAutoMLTrainModelOperator(unittest.TestCase):
project_id=GCP_PROJECT_ID,
task_id=TASK_ID,
)
- op.execute(context=None)
+ op.execute(context=mock.MagicMock())
mock_hook.return_value.create_model.assert_called_once_with(
model=MODEL,
location=GCP_LOCATION,
@@ -88,7 +87,6 @@ class TestAutoMLTrainModelOperator(unittest.TestCase):
timeout=None,
metadata=(),
)
- mock_xcom.assert_called_once_with(None, key="model_id", value=MODEL_ID)
class TestAutoMLBatchPredictOperator(unittest.TestCase):
@@ -106,7 +104,7 @@ class TestAutoMLBatchPredictOperator(unittest.TestCase):
task_id=TASK_ID,
prediction_params={},
)
- op.execute(context=None)
+ op.execute(context=mock.MagicMock())
mock_hook.return_value.batch_predict.assert_called_once_with(
input_config=INPUT_CONFIG,
location=GCP_LOCATION,
@@ -133,7 +131,7 @@ class TestAutoMLPredictOperator(unittest.TestCase):
task_id=TASK_ID,
operation_params={"TEST_KEY": "TEST_VALUE"},
)
- op.execute(context=None)
+ op.execute(context=mock.MagicMock())
mock_hook.return_value.predict.assert_called_once_with(
location=GCP_LOCATION,
metadata=(),
@@ -147,9 +145,8 @@ class TestAutoMLPredictOperator(unittest.TestCase):
class TestAutoMLCreateImportOperator(unittest.TestCase):
- @mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator.xcom_push")
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook, mock_xcom):
+ def test_execute(self, mock_hook):
mock_hook.return_value.create_dataset.return_value = Dataset(name=DATASET_PATH)
mock_hook.return_value.extract_object_id = extract_object_id
@@ -159,7 +156,7 @@ class TestAutoMLCreateImportOperator(unittest.TestCase):
project_id=GCP_PROJECT_ID,
task_id=TASK_ID,
)
- op.execute(context=None)
+ op.execute(context=mock.MagicMock())
mock_hook.return_value.create_dataset.assert_called_once_with(
dataset=DATASET,
location=GCP_LOCATION,
@@ -168,7 +165,6 @@ class TestAutoMLCreateImportOperator(unittest.TestCase):
retry=DEFAULT,
timeout=None,
)
- mock_xcom.assert_called_once_with(None, key="dataset_id",
value=DATASET_ID)
class TestAutoMLListColumnsSpecsOperator(unittest.TestCase):
@@ -188,7 +184,7 @@ class TestAutoMLListColumnsSpecsOperator(unittest.TestCase):
page_size=page_size,
task_id=TASK_ID,
)
- op.execute(context=None)
+ op.execute(context=mock.MagicMock())
mock_hook.return_value.list_column_specs.assert_called_once_with(
dataset_id=DATASET_ID,
field_mask=MASK,
@@ -217,7 +213,7 @@ class TestAutoMLUpdateDatasetOperator(unittest.TestCase):
location=GCP_LOCATION,
task_id=TASK_ID,
)
- op.execute(context=None)
+ op.execute(context=mock.MagicMock())
mock_hook.return_value.update_dataset.assert_called_once_with(
dataset=dataset,
metadata=(),
@@ -239,7 +235,7 @@ class TestAutoMLGetModelOperator(unittest.TestCase):
project_id=GCP_PROJECT_ID,
task_id=TASK_ID,
)
- op.execute(context=None)
+ op.execute(context=mock.MagicMock())
mock_hook.return_value.get_model.assert_called_once_with(
location=GCP_LOCATION,
metadata=(),
@@ -303,7 +299,7 @@ class TestAutoMLDatasetImportOperator(unittest.TestCase):
input_config=INPUT_CONFIG,
task_id=TASK_ID,
)
- op.execute(context=None)
+ op.execute(context=mock.MagicMock())
mock_hook.return_value.import_data.assert_called_once_with(
input_config=INPUT_CONFIG,
location=GCP_LOCATION,
@@ -329,7 +325,7 @@ class TestAutoMLTablesListTableSpecsOperator(unittest.TestCase):
page_size=page_size,
task_id=TASK_ID,
)
- op.execute(context=None)
+ op.execute(context=mock.MagicMock())
mock_hook.return_value.list_table_specs.assert_called_once_with(
dataset_id=DATASET_ID,
filter_=filter_,
@@ -343,11 +339,10 @@ class TestAutoMLTablesListTableSpecsOperator(unittest.TestCase):
class TestAutoMLDatasetListOperator(unittest.TestCase):
- @mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator.xcom_push")
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook, mock_xcom):
+ def test_execute(self, mock_hook):
op = AutoMLListDatasetOperator(location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID)
- op.execute(context=None)
+ op.execute(context=mock.MagicMock())
mock_hook.return_value.list_datasets.assert_called_once_with(
location=GCP_LOCATION,
metadata=(),
@@ -355,7 +350,6 @@ class TestAutoMLDatasetListOperator(unittest.TestCase):
retry=DEFAULT,
timeout=None,
)
- mock_xcom.assert_called_once_with(None, key="dataset_id_list",
value=[])
class TestAutoMLDatasetDeleteOperator(unittest.TestCase):
diff --git a/tests/providers/google/cloud/operators/test_automl_system.py b/tests/providers/google/cloud/operators/test_automl_system.py
deleted file mode 100644
index dbda699fda..0000000000
--- a/tests/providers/google/cloud/operators/test_automl_system.py
+++ /dev/null
@@ -1,41 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pytest
-
-from tests.providers.google.cloud.utils.gcp_authenticator import GCP_AUTOML_KEY
-from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
-
-
[email protected]("mysql", "postgres")
[email protected]_file(GCP_AUTOML_KEY)
[email protected]_running
-class TestAutoMLDatasetOperationsSystem(GoogleSystemTest):
- @provide_gcp_context(GCP_AUTOML_KEY)
- def test_run_example_dag(self):
- self.run_dag("example_automl_dataset", CLOUD_DAG_FOLDER)
-
-
[email protected]("mysql", "postgres")
[email protected]_file(GCP_AUTOML_KEY)
[email protected]_running
-class TestAutoMLModelOperationsSystem(GoogleSystemTest):
- @provide_gcp_context(GCP_AUTOML_KEY)
- def test_run_example_dag(self):
- self.run_dag("example_create_and_deploy", CLOUD_DAG_FOLDER)
diff --git a/tests/providers/google/cloud/utils/gcp_authenticator.py b/tests/providers/google/cloud/utils/gcp_authenticator.py
index f2c9312074..7c95e57dc3 100644
--- a/tests/providers/google/cloud/utils/gcp_authenticator.py
+++ b/tests/providers/google/cloud/utils/gcp_authenticator.py
@@ -30,7 +30,6 @@ from tests.test_utils import AIRFLOW_MAIN_FOLDER
from tests.test_utils.logging_command_executor import CommandExecutor
GCP_AI_KEY = "gcp_ai.json"
-GCP_AUTOML_KEY = "gcp_automl.json"
GCP_BIGQUERY_KEY = "gcp_bigquery.json"
GCP_BIGTABLE_KEY = "gcp_bigtable.json"
GCP_CLOUD_BUILD_KEY = "gcp_cloud_build.json"
diff --git a/tests/system/providers/google/cloud/automl/example_automl_dataset.py b/tests/system/providers/google/cloud/automl/example_automl_dataset.py
new file mode 100644
index 0000000000..909da1593a
--- /dev/null
+++ b/tests/system/providers/google/cloud/automl/example_automl_dataset.py
@@ -0,0 +1,201 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG for Google AutoML service testing dataset operations.
+"""
+from __future__ import annotations
+
+import os
+from copy import deepcopy
+from datetime import datetime
+
+from airflow import models
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.operators.automl import (
+ AutoMLCreateDatasetOperator,
+ AutoMLDeleteDatasetOperator,
+ AutoMLImportDataOperator,
+ AutoMLListDatasetOperator,
+ AutoMLTablesListColumnSpecsOperator,
+ AutoMLTablesListTableSpecsOperator,
+ AutoMLTablesUpdateDatasetOperator,
+)
+from airflow.providers.google.cloud.operators.gcs import (
+ GCSCreateBucketOperator,
+ GCSDeleteBucketOperator,
+ GCSSynchronizeBucketsOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "automl_dataset"
+GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+
+GCP_AUTOML_LOCATION = "us-central1"
+
+DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}"
+RESOURCE_DATA_BUCKET = "system-tests-resources"
+
+DATASET_NAME = f"ds_{DAG_ID}_{ENV_ID}"
+DATASET = {
+ "display_name": DATASET_NAME,
+ "tables_dataset_metadata": {"target_column_spec_id": ""},
+}
+AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/bank-marketing.csv"
+IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
+
+extract_object_id = CloudAutoMLHook.extract_object_id
+
+
+def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str:
+ """
+ Using column name returns spec of the column.
+ """
+ for column in columns_specs:
+ if column["display_name"] == column_name:
+ return extract_object_id(column)
+ raise Exception(f"Unknown target column: {column_name}")
+
+
+with models.DAG(
+ dag_id=DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "automl"],
+ user_defined_macros={
+ "get_target_column_spec": get_target_column_spec,
+ "target": "Class",
+ "extract_object_id": extract_object_id,
+ },
+) as dag:
+ create_bucket = GCSCreateBucketOperator(
+ task_id="create_bucket",
+ bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+ storage_class="REGIONAL",
+ location=GCP_AUTOML_LOCATION,
+ )
+
+ move_dataset_file = GCSSynchronizeBucketsOperator(
+ task_id="move_dataset_to_bucket",
+ source_bucket=RESOURCE_DATA_BUCKET,
+ source_object="automl",
+ destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
+ destination_object="automl",
+ recursive=True,
+ )
+
+ # [START howto_operator_automl_create_dataset]
+ create_dataset = AutoMLCreateDatasetOperator(
+ task_id="create_dataset",
+ dataset=DATASET,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ dataset_id = create_dataset.output["dataset_id"]
+ # [END howto_operator_automl_create_dataset]
+
+ # [START howto_operator_automl_import_data]
+ import_dataset_task = AutoMLImportDataOperator(
+ task_id="import_dataset_task",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ input_config=IMPORT_INPUT_CONFIG,
+ )
+ # [END howto_operator_automl_import_data]
+
+ # [START howto_operator_automl_specs]
+ list_tables_spec_task = AutoMLTablesListTableSpecsOperator(
+ task_id="list_tables_spec_task",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_automl_specs]
+
+ # [START howto_operator_automl_column_specs]
+ list_columns_spec_task = AutoMLTablesListColumnSpecsOperator(
+ task_id="list_columns_spec_task",
+ dataset_id=dataset_id,
+ table_spec_id="{{
extract_object_id(task_instance.xcom_pull('list_tables_spec_task')[0]) }}",
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_automl_column_specs]
+
+ # [START howto_operator_automl_update_dataset]
+ update = deepcopy(DATASET)
+ update["name"] = '{{ task_instance.xcom_pull("create_dataset")["name"] }}'
+ update["tables_dataset_metadata"][ # type: ignore
+ "target_column_spec_id"
+ ] = "{{
get_target_column_spec(task_instance.xcom_pull('list_columns_spec_task'),
target) }}"
+
+ update_dataset_task = AutoMLTablesUpdateDatasetOperator(
+ task_id="update_dataset_task",
+ dataset=update,
+ location=GCP_AUTOML_LOCATION,
+ )
+ # [END howto_operator_automl_update_dataset]
+
+ # [START howto_operator_list_dataset]
+ list_datasets_task = AutoMLListDatasetOperator(
+ task_id="list_datasets_task",
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_list_dataset]
+
+ # [START howto_operator_delete_dataset]
+ delete_dataset = AutoMLDeleteDatasetOperator(
+ task_id="delete_dataset",
+ dataset_id="{{ task_instance.xcom_pull('list_datasets_task',
key='dataset_id_list') | list }}",
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_delete_dataset]
+
+ delete_bucket = GCSDeleteBucketOperator(
+ task_id="delete_bucket", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
trigger_rule=TriggerRule.ALL_DONE
+ )
+
+ (
+ # TEST SETUP
+ [create_bucket >> move_dataset_file, create_dataset]
+ # TEST BODY
+ >> import_dataset_task
+ >> list_tables_spec_task
+ >> list_columns_spec_task
+ >> update_dataset_task
+ >> list_datasets_task
+ # TEST TEARDOWN
+ >> delete_dataset
+ >> delete_bucket
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/google/cloud/automl/example_automl_model.py b/tests/system/providers/google/cloud/automl/example_automl_model.py
new file mode 100644
index 0000000000..f08b46b024
--- /dev/null
+++ b/tests/system/providers/google/cloud/automl/example_automl_model.py
@@ -0,0 +1,285 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG for Google AutoML service testing model operations.
+"""
+from __future__ import annotations
+
+import os
+from copy import deepcopy
+from datetime import datetime
+
+from google.protobuf.struct_pb2 import Value
+
+from airflow import models
+from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
+from airflow.providers.google.cloud.operators.automl import (
+ AutoMLBatchPredictOperator,
+ AutoMLCreateDatasetOperator,
+ AutoMLDeleteDatasetOperator,
+ AutoMLDeleteModelOperator,
+ AutoMLDeployModelOperator,
+ AutoMLGetModelOperator,
+ AutoMLImportDataOperator,
+ AutoMLPredictOperator,
+ AutoMLTablesListColumnSpecsOperator,
+ AutoMLTablesListTableSpecsOperator,
+ AutoMLTablesUpdateDatasetOperator,
+ AutoMLTrainModelOperator,
+)
+from airflow.providers.google.cloud.operators.gcs import (
+ GCSCreateBucketOperator,
+ GCSDeleteBucketOperator,
+ GCSSynchronizeBucketsOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "automl_model"
+GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+
+GCP_AUTOML_LOCATION = "us-central1"
+
+DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}"
+RESOURCE_DATA_BUCKET = "system-tests-resources"
+
+DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}"
+DATASET = {
+ "display_name": DATASET_NAME,
+ "tables_dataset_metadata": {"target_column_spec_id": ""},
+}
+AUTOML_DATASET_BUCKET = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/bank-marketing.csv"
+IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
+IMPORT_OUTPUT_CONFIG = {
+ "gcs_destination": {"output_uri_prefix":
f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl"}
+}
+
+MODEL_NAME = f"model_{DAG_ID}_{ENV_ID}"
+MODEL = {
+ "display_name": MODEL_NAME,
+ "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
+}
+
+PREDICT_VALUES = [
+ Value(string_value="51"),
+ Value(string_value="blue-collar"),
+ Value(string_value="married"),
+ Value(string_value="primary"),
+ Value(string_value="no"),
+ Value(string_value="620"),
+ Value(string_value="yes"),
+ Value(string_value="yes"),
+ Value(string_value="cellular"),
+ Value(string_value="29"),
+ Value(string_value="jul"),
+ Value(string_value="88"),
+ Value(string_value="10"),
+ Value(string_value="-1"),
+ Value(string_value="0"),
+ Value(string_value="unknown"),
+]
+
+extract_object_id = CloudAutoMLHook.extract_object_id
+
+
+def get_target_column_spec(columns_specs: list[dict], column_name: str) -> str:
+ """
+ Using column name returns spec of the column.
+ """
+ for column in columns_specs:
+ if column["display_name"] == column_name:
+ return extract_object_id(column)
+ raise Exception(f"Unknown target column: {column_name}")
+
+
+with models.DAG(
+ dag_id=DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ user_defined_macros={
+ "get_target_column_spec": get_target_column_spec,
+ "target": "Class",
+ "extract_object_id": extract_object_id,
+ },
+ tags=["example", "automl"],
+) as dag:
+ create_bucket = GCSCreateBucketOperator(
+ task_id="create_bucket",
+ bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+ storage_class="REGIONAL",
+ location=GCP_AUTOML_LOCATION,
+ )
+
+ move_dataset_file = GCSSynchronizeBucketsOperator(
+ task_id="move_data_to_bucket",
+ source_bucket=RESOURCE_DATA_BUCKET,
+ source_object="automl",
+ destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
+ destination_object="automl",
+ recursive=True,
+ )
+
+ create_dataset = AutoMLCreateDatasetOperator(
+ task_id="create_dataset",
+ dataset=DATASET,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+
+ dataset_id = create_dataset.output["dataset_id"]
+ MODEL["dataset_id"] = dataset_id
+ import_dataset = AutoMLImportDataOperator(
+ task_id="import_dataset",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ input_config=IMPORT_INPUT_CONFIG,
+ )
+
+ list_tables_spec = AutoMLTablesListTableSpecsOperator(
+ task_id="list_tables_spec",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+
+ list_columns_spec = AutoMLTablesListColumnSpecsOperator(
+ task_id="list_columns_spec",
+ dataset_id=dataset_id,
+ table_spec_id="{{
extract_object_id(task_instance.xcom_pull('list_tables_spec')[0]) }}",
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+
+ update = deepcopy(DATASET)
+ update["name"] = '{{ task_instance.xcom_pull("create_dataset")["name"] }}'
+ update["tables_dataset_metadata"][ # type: ignore
+ "target_column_spec_id"
+ ] = "{{
get_target_column_spec(task_instance.xcom_pull('list_columns_spec'), target) }}"
+
+ update_dataset = AutoMLTablesUpdateDatasetOperator(
+ task_id="update_dataset",
+ dataset=update,
+ location=GCP_AUTOML_LOCATION,
+ )
+
+ # [START howto_operator_automl_create_model]
+ create_model = AutoMLTrainModelOperator(
+ task_id="create_model",
+ model=MODEL,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ model_id = create_model.output["model_id"]
+ # [END howto_operator_automl_create_model]
+
+ # [START howto_operator_get_model]
+ get_model = AutoMLGetModelOperator(
+ task_id="get_model",
+ model_id=model_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_get_model]
+
+ # [START howto_operator_deploy_model]
+ deploy_model = AutoMLDeployModelOperator(
+ task_id="deploy_model",
+ model_id=model_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_deploy_model]
+
+ # [START howto_operator_prediction]
+ predict_task = AutoMLPredictOperator(
+ task_id="predict_task",
+ model_id=model_id,
+ payload={
+ "row": {
+ "values": PREDICT_VALUES,
+ }
+ },
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_prediction]
+
+ # [START howto_operator_batch_prediction]
+ batch_predict_task = AutoMLBatchPredictOperator(
+ task_id="batch_predict_task",
+ model_id=model_id,
+ input_config=IMPORT_INPUT_CONFIG,
+ output_config=IMPORT_OUTPUT_CONFIG,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_batch_prediction]
+
+ # [START howto_operator_automl_delete_model]
+ delete_model = AutoMLDeleteModelOperator(
+ task_id="delete_model",
+ model_id=model_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ )
+ # [END howto_operator_automl_delete_model]
+
+ delete_dataset = AutoMLDeleteDatasetOperator(
+ task_id="delete_dataset",
+ dataset_id=dataset_id,
+ location=GCP_AUTOML_LOCATION,
+ project_id=GCP_PROJECT_ID,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ delete_bucket = GCSDeleteBucketOperator(
+ task_id="delete_bucket", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
trigger_rule=TriggerRule.ALL_DONE
+ )
+
+ (
+ # TEST SETUP
+ [create_bucket >> move_dataset_file, create_dataset]
+ >> import_dataset
+ >> list_tables_spec
+ >> list_columns_spec
+ >> update_dataset
+ # TEST BODY
+ >> create_model
+ >> get_model
+ >> deploy_model
+ >> predict_task
+ >> batch_predict_task
+ # TEST TEARDOWN
+ >> delete_model
+ >> delete_dataset
+ >> delete_bucket
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/google/cloud/automl/resources/__init__.py b/tests/system/providers/google/cloud/automl/resources/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/system/providers/google/cloud/automl/resources/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.