This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 25d463c3e3 Deprecate AutoMLTrainModelOperator for NL (#34212)
25d463c3e3 is described below

commit 25d463c3e33f8628e1bcbe4dc6924693ec141dc0
Author: VladaZakharova <[email protected]>
AuthorDate: Mon Sep 11 12:53:08 2023 +0200

    Deprecate AutoMLTrainModelOperator for NL (#34212)
---
 airflow/providers/google/cloud/operators/automl.py |  21 +++-
 .../operators/cloud/automl.rst                     |  14 ++-
 .../example_automl_nl_text_classification.py       | 121 ++++++++++++---------
 .../automl/example_automl_nl_text_extraction.py    | 117 +++++++++++---------
 .../automl/example_automl_nl_text_sentiment.py     | 115 ++++++++++++--------
 5 files changed, 240 insertions(+), 148 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/automl.py 
b/airflow/providers/google/cloud/operators/automl.py
index 6fefa3081a..aee9fc9631 100644
--- a/airflow/providers/google/cloud/operators/automl.py
+++ b/airflow/providers/google/cloud/operators/automl.py
@@ -19,6 +19,7 @@
 from __future__ import annotations
 
 import ast
+import warnings
 from typing import TYPE_CHECKING, Sequence, Tuple
 
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -31,6 +32,7 @@ from google.cloud.automl_v1beta1 import (
     TableSpec,
 )
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from airflow.providers.google.cloud.links.automl import (
     AutoMLDatasetLink,
@@ -53,6 +55,10 @@ class AutoMLTrainModelOperator(GoogleCloudBaseOperator):
     """
     Creates Google Cloud AutoML model.
 
+    AutoMLTrainModelOperator for text prediction is deprecated. Please use
+    
:class:`airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`
+    instead.
+
     .. seealso::
         For more information on how to use this operator, take a look at the 
guide:
         :ref:`howto/operator:AutoMLTrainModelOperator`
@@ -102,7 +108,6 @@ class AutoMLTrainModelOperator(GoogleCloudBaseOperator):
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-
         self.model = model
         self.location = location
         self.project_id = project_id
@@ -113,6 +118,20 @@ class AutoMLTrainModelOperator(GoogleCloudBaseOperator):
         self.impersonation_chain = impersonation_chain
 
     def execute(self, context: Context):
+        # Output warning if running AutoML Natural Language prediction job
+        automl_nl_model_keys = [
+            "text_classification_model_metadata",
+            "text_extraction_model_metadata",
+            "text_sentiment_model_metadata",
+        ]
+        if any(key in automl_nl_model_keys for key in self.model):
+            warnings.warn(
+                "AutoMLTrainModelOperator for text prediction is deprecated. 
All the functionality of legacy "
+                "AutoML Natural Language and new features are available on the 
Vertex AI platform. "
+                "Please use `CreateAutoMLTextTrainingJobOperator`",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
         hook = CloudAutoMLHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
diff --git a/docs/apache-airflow-providers-google/operators/cloud/automl.rst 
b/docs/apache-airflow-providers-google/operators/cloud/automl.rst
index 28821c4c6f..b283f51c07 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/automl.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/automl.rst
@@ -102,6 +102,16 @@ To create a Google AutoML model you can use
 The operator will wait for the operation to complete. Additionally the operator
 returns the id of model in :ref:`XCom <concepts:xcom>` under ``model_id`` key.
 
+This Operator is deprecated when running for text prediction and will be 
removed soon.
+All the functionality of legacy AutoML Natural Language and new features are 
available on the
+Vertex AI platform. Please use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTextTrainingJobOperator`.
+When running Vertex AI Operator for training data, please ensure that your data 
is correctly stored in Vertex AI
+datasets. To create and import data to the dataset please use
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator`
+and
+:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator`
+
 .. exampleinclude:: 
/../../tests/system/providers/google/cloud/automl/example_automl_model.py
     :language: python
     :dedent: 4
@@ -164,7 +174,7 @@ the model must be deployed.
 Listing And Deleting Datasets
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-You can get a list of AutoML models using
+You can get a list of AutoML datasets using
 
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator`.
 The operator returns list
 of datasets ids in :ref:`XCom <concepts:xcom>` under ``dataset_id_list`` key.
 
@@ -174,7 +184,7 @@ of datasets ids in :ref:`XCom <concepts:xcom>` under 
``dataset_id_list`` key.
     :start-after: [START howto_operator_list_dataset]
     :end-before: [END howto_operator_list_dataset]
 
-To delete a model you can use 
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`.
+To delete a dataset you can use 
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`.
 The delete operator allows also to pass list or coma separated string of 
datasets ids to be deleted.
 
 .. exampleinclude:: 
/../../tests/system/providers/google/cloud/automl/example_automl_dataset.py
diff --git 
a/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
 
b/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
index 0a04b3b361..753c91dfd0 100644
--- 
a/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
+++ 
b/tests/system/providers/google/cloud/automl/example_automl_nl_text_classification.py
@@ -24,47 +24,54 @@ import os
 from datetime import datetime
 from typing import cast
 
+from google.cloud.aiplatform import schema
+from google.protobuf.struct_pb2 import Value
+
 from airflow import models
 from airflow.models.xcom_arg import XComArg
 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
-from airflow.providers.google.cloud.operators.automl import (
-    AutoMLCreateDatasetOperator,
-    AutoMLDeleteDatasetOperator,
-    AutoMLDeleteModelOperator,
-    AutoMLDeployModelOperator,
-    AutoMLImportDataOperator,
-    AutoMLTrainModelOperator,
-)
 from airflow.providers.google.cloud.operators.gcs import (
     GCSCreateBucketOperator,
     GCSDeleteBucketOperator,
     GCSSynchronizeBucketsOperator,
 )
+from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
+    CreateAutoMLTextTrainingJobOperator,
+    DeleteAutoMLTrainingJobOperator,
+)
+from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
+    CreateDatasetOperator,
+    DeleteDatasetOperator,
+    ImportDataOperator,
+)
 from airflow.utils.trigger_rule import TriggerRule
 
 ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
-DAG_ID = "example_automl_text_cls"
 GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+DAG_ID = "example_automl_text_cls"
 
 GCP_AUTOML_LOCATION = "us-central1"
 DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
 RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
 
-MODEL_NAME = "text_clss_test_model"
-MODEL = {
-    "display_name": MODEL_NAME,
-    "text_classification_model_metadata": {},
-}
+TEXT_CLSS_DISPLAY_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
+AUTOML_DATASET_BUCKET = 
f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/classification.csv"
+
+MODEL_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
 
 DATASET_NAME = f"ds_clss_{ENV_ID}".replace("-", "_")
 DATASET = {
     "display_name": DATASET_NAME,
-    "text_classification_dataset_metadata": {"classification_type": 
"MULTICLASS"},
+    "metadata_schema_uri": schema.dataset.metadata.text,
+    "metadata": Value(string_value="clss-dataset"),
 }
 
-AUTOML_DATASET_BUCKET = 
f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/text_classification.csv"
-IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
-
+DATA_CONFIG = [
+    {
+        "import_schema_uri": 
schema.dataset.ioformat.text.single_label_classification,
+        "gcs_source": {"uris": [AUTOML_DATASET_BUCKET]},
+    },
+]
 extract_object_id = CloudAutoMLHook.extract_object_id
 
 # Example DAG for AutoML Natural Language Text Classification
@@ -85,67 +92,77 @@ with models.DAG(
     move_dataset_file = GCSSynchronizeBucketsOperator(
         task_id="move_dataset_to_bucket",
         source_bucket=RESOURCE_DATA_BUCKET,
-        source_object="automl/datasets/text",
+        source_object="vertex-ai/automl/datasets/text",
         destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
         destination_object="automl",
         recursive=True,
     )
 
-    create_dataset = AutoMLCreateDatasetOperator(
-        task_id="create_dataset",
+    create_clss_dataset = CreateDatasetOperator(
+        task_id="create_clss_dataset",
         dataset=DATASET,
-        location=GCP_AUTOML_LOCATION,
+        region=GCP_AUTOML_LOCATION,
         project_id=GCP_PROJECT_ID,
     )
+    clss_dataset_id = create_clss_dataset.output["dataset_id"]
 
-    dataset_id = cast(str, XComArg(create_dataset, key="dataset_id"))
-    MODEL["dataset_id"] = dataset_id
-    import_dataset = AutoMLImportDataOperator(
-        task_id="import_dataset",
-        dataset_id=dataset_id,
-        location=GCP_AUTOML_LOCATION,
-        input_config=IMPORT_INPUT_CONFIG,
+    import_clss_dataset = ImportDataOperator(
+        task_id="import_clss_data",
+        dataset_id=clss_dataset_id,
+        region=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+        import_configs=DATA_CONFIG,
     )
-    MODEL["dataset_id"] = dataset_id
-
-    create_model = AutoMLTrainModelOperator(task_id="create_model", 
model=MODEL, location=GCP_AUTOML_LOCATION)
-    model_id = cast(str, XComArg(create_model, key="model_id"))
 
-    deploy_model = AutoMLDeployModelOperator(
-        task_id="deploy_model",
-        model_id=model_id,
-        location=GCP_AUTOML_LOCATION,
+    # [START howto_operator_automl_create_model]
+    create_clss_training_job = CreateAutoMLTextTrainingJobOperator(
+        task_id="create_clss_training_job",
+        display_name=TEXT_CLSS_DISPLAY_NAME,
+        prediction_type="classification",
+        multi_label=False,
+        dataset_id=clss_dataset_id,
+        model_display_name=MODEL_NAME,
+        training_fraction_split=0.7,
+        validation_fraction_split=0.2,
+        test_fraction_split=0.1,
+        sync=True,
+        region=GCP_AUTOML_LOCATION,
         project_id=GCP_PROJECT_ID,
     )
+    # [END howto_operator_automl_create_model]
+    model_id = cast(str, XComArg(create_clss_training_job, key="model_id"))
 
-    delete_model = AutoMLDeleteModelOperator(
-        task_id="delete_model",
-        model_id=model_id,
-        location=GCP_AUTOML_LOCATION,
+    delete_clss_training_job = DeleteAutoMLTrainingJobOperator(
+        task_id="delete_clss_training_job",
+        training_pipeline_id=create_clss_training_job.output["training_id"],
+        region=GCP_AUTOML_LOCATION,
         project_id=GCP_PROJECT_ID,
+        trigger_rule=TriggerRule.ALL_DONE,
     )
 
-    delete_dataset = AutoMLDeleteDatasetOperator(
-        task_id="delete_dataset",
-        dataset_id=dataset_id,
-        location=GCP_AUTOML_LOCATION,
+    delete_clss_dataset = DeleteDatasetOperator(
+        task_id="delete_clss_dataset",
+        dataset_id=clss_dataset_id,
+        region=GCP_AUTOML_LOCATION,
         project_id=GCP_PROJECT_ID,
+        trigger_rule=TriggerRule.ALL_DONE,
     )
 
     delete_bucket = GCSDeleteBucketOperator(
-        task_id="delete_bucket", bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, 
trigger_rule=TriggerRule.ALL_DONE
+        task_id="delete_bucket",
+        bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+        trigger_rule=TriggerRule.ALL_DONE,
     )
 
     (
         # TEST SETUP
-        [create_bucket >> move_dataset_file, create_dataset]
+        [create_bucket >> move_dataset_file, create_clss_dataset]
         # TEST BODY
-        >> import_dataset
-        >> create_model
-        >> deploy_model
+        >> import_clss_dataset
+        >> create_clss_training_job
         # TEST TEARDOWN
-        >> delete_model
-        >> delete_dataset
+        >> delete_clss_training_job
+        >> delete_clss_dataset
         >> delete_bucket
     )
 
diff --git 
a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
 
b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
index 260a7d84c6..06b22779a8 100644
--- 
a/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
+++ 
b/tests/system/providers/google/cloud/automl/example_automl_nl_text_extraction.py
@@ -24,42 +24,52 @@ import os
 from datetime import datetime
 from typing import cast
 
+from google.cloud.aiplatform import schema
+from google.protobuf.struct_pb2 import Value
+
 from airflow import models
 from airflow.models.xcom_arg import XComArg
 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
-from airflow.providers.google.cloud.operators.automl import (
-    AutoMLCreateDatasetOperator,
-    AutoMLDeleteDatasetOperator,
-    AutoMLDeleteModelOperator,
-    AutoMLImportDataOperator,
-    AutoMLTrainModelOperator,
-)
 from airflow.providers.google.cloud.operators.gcs import (
     GCSCreateBucketOperator,
     GCSDeleteBucketOperator,
     GCSSynchronizeBucketsOperator,
 )
+from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
+    CreateAutoMLTextTrainingJobOperator,
+    DeleteAutoMLTrainingJobOperator,
+)
+from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
+    CreateDatasetOperator,
+    DeleteDatasetOperator,
+    ImportDataOperator,
+)
 from airflow.utils.trigger_rule import TriggerRule
 
 ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
-DAG_ID = "example_automl_text_extr"
 GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+DAG_ID = "example_automl_text_extr"
 
 GCP_AUTOML_LOCATION = "us-central1"
 RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
 
 DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
-
-DATASET_NAME = f"ds_extr_{ENV_ID}".replace("-", "_")
-DATASET = {"display_name": DATASET_NAME, "text_extraction_dataset_metadata": 
{}}
-AUTOML_DATASET_BUCKET = 
f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/text_extraction.csv"
-IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
-
-MODEL_NAME = "entity_extr_test_model"
-MODEL = {
-    "display_name": MODEL_NAME,
-    "text_extraction_model_metadata": {},
+TEXT_EXTR_DISPLAY_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
+AUTOML_DATASET_BUCKET = 
f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/extraction.jsonl"
+
+MODEL_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
+DATASET_NAME = f"ds_extr_{ENV_ID}".replace("-", "_")
+DATASET = {
+    "display_name": DATASET_NAME,
+    "metadata_schema_uri": schema.dataset.metadata.text,
+    "metadata": Value(string_value="extr-dataset"),
 }
+DATA_CONFIG = [
+    {
+        "import_schema_uri": schema.dataset.ioformat.text.extraction,
+        "gcs_source": {"uris": [AUTOML_DATASET_BUCKET]},
+    },
+]
 
 extract_object_id = CloudAutoMLHook.extract_object_id
 
@@ -80,51 +90,60 @@ with models.DAG(
     )
 
     move_dataset_file = GCSSynchronizeBucketsOperator(
-        task_id="move_data_to_bucket",
+        task_id="move_dataset_to_bucket",
         source_bucket=RESOURCE_DATA_BUCKET,
-        source_object="automl/datasets/text",
+        source_object="vertex-ai/automl/datasets/text",
         destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
         destination_object="automl",
         recursive=True,
     )
 
-    create_dataset = AutoMLCreateDatasetOperator(
-        task_id="create_dataset",
+    create_extr_dataset = CreateDatasetOperator(
+        task_id="create_extr_dataset",
         dataset=DATASET,
-        location=GCP_AUTOML_LOCATION,
+        region=GCP_AUTOML_LOCATION,
         project_id=GCP_PROJECT_ID,
     )
+    extr_dataset_id = create_extr_dataset.output["dataset_id"]
 
-    dataset_id = cast(str, XComArg(create_dataset, key="dataset_id"))
-    MODEL["dataset_id"] = dataset_id
-    import_dataset = AutoMLImportDataOperator(
-        task_id="import_dataset",
-        dataset_id=dataset_id,
-        location=GCP_AUTOML_LOCATION,
-        input_config=IMPORT_INPUT_CONFIG,
+    import_extr_dataset = ImportDataOperator(
+        task_id="import_extr_data",
+        dataset_id=extr_dataset_id,
+        region=GCP_AUTOML_LOCATION,
         project_id=GCP_PROJECT_ID,
+        import_configs=DATA_CONFIG,
     )
-    MODEL["dataset_id"] = dataset_id
 
-    create_model = AutoMLTrainModelOperator(
-        task_id="create_model",
-        model=MODEL,
-        location=GCP_AUTOML_LOCATION,
+    # [START howto_operator_automl_create_model]
+    create_extr_training_job = CreateAutoMLTextTrainingJobOperator(
+        task_id="create_extr_training_job",
+        display_name=TEXT_EXTR_DISPLAY_NAME,
+        prediction_type="extraction",
+        multi_label=False,
+        dataset_id=extr_dataset_id,
+        model_display_name=MODEL_NAME,
+        training_fraction_split=0.8,
+        validation_fraction_split=0.1,
+        test_fraction_split=0.1,
+        sync=True,
+        region=GCP_AUTOML_LOCATION,
         project_id=GCP_PROJECT_ID,
     )
-    model_id = cast(str, XComArg(create_model, key="model_id"))
+    # [END howto_operator_automl_create_model]
+    model_id = cast(str, XComArg(create_extr_training_job, key="model_id"))
 
-    delete_model_task = AutoMLDeleteModelOperator(
-        task_id="delete_model_task",
-        model_id=model_id,
-        location=GCP_AUTOML_LOCATION,
+    delete_extr_training_job = DeleteAutoMLTrainingJobOperator(
+        task_id="delete_extr_training_job",
+        training_pipeline_id=create_extr_training_job.output["training_id"],
+        region=GCP_AUTOML_LOCATION,
         project_id=GCP_PROJECT_ID,
+        trigger_rule=TriggerRule.ALL_DONE,
     )
 
-    delete_datasets_task = AutoMLDeleteDatasetOperator(
-        task_id="delete_datasets_task",
-        dataset_id=dataset_id,
-        location=GCP_AUTOML_LOCATION,
+    delete_extr_dataset = DeleteDatasetOperator(
+        task_id="delete_extr_dataset",
+        dataset_id=extr_dataset_id,
+        region=GCP_AUTOML_LOCATION,
         project_id=GCP_PROJECT_ID,
         trigger_rule=TriggerRule.ALL_DONE,
     )
@@ -137,13 +156,13 @@ with models.DAG(
 
     (
         # TEST SETUP
-        [create_bucket >> move_dataset_file, create_dataset]
+        [create_bucket >> move_dataset_file, create_extr_dataset]
         # TEST BODY
-        >> import_dataset
-        >> create_model
+        >> import_extr_dataset
+        >> create_extr_training_job
         # TEST TEARDOWN
-        >> delete_model_task
-        >> delete_datasets_task
+        >> delete_extr_training_job
+        >> delete_extr_dataset
         >> delete_bucket
     )
 
diff --git 
a/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
 
b/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
index 3559339755..1529f07bc6 100644
--- 
a/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
+++ 
b/tests/system/providers/google/cloud/automl/example_automl_nl_text_sentiment.py
@@ -24,44 +24,53 @@ import os
 from datetime import datetime
 from typing import cast
 
+from google.cloud.aiplatform import schema
+from google.protobuf.struct_pb2 import Value
+
 from airflow import models
 from airflow.models.xcom_arg import XComArg
 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
-from airflow.providers.google.cloud.operators.automl import (
-    AutoMLCreateDatasetOperator,
-    AutoMLDeleteDatasetOperator,
-    AutoMLDeleteModelOperator,
-    AutoMLImportDataOperator,
-    AutoMLTrainModelOperator,
-)
 from airflow.providers.google.cloud.operators.gcs import (
     GCSCreateBucketOperator,
     GCSDeleteBucketOperator,
     GCSSynchronizeBucketsOperator,
 )
+from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
+    CreateAutoMLTextTrainingJobOperator,
+    DeleteAutoMLTrainingJobOperator,
+)
+from airflow.providers.google.cloud.operators.vertex_ai.dataset import (
+    CreateDatasetOperator,
+    DeleteDatasetOperator,
+    ImportDataOperator,
+)
 from airflow.utils.trigger_rule import TriggerRule
 
 ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
-DAG_ID = "example_automl_text_sent"
 GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+DAG_ID = "example_automl_text_sent"
 GCP_AUTOML_LOCATION = "us-central1"
 DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
 RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
 
-MODEL_NAME = "text_sent_test_model"
-MODEL = {
-    "display_name": MODEL_NAME,
-    "text_sentiment_model_metadata": {},
-}
+TEXT_SENT_DISPLAY_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
+AUTOML_DATASET_BUCKET = 
f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/sentiment.csv"
+
+MODEL_NAME = f"{DAG_ID}-{ENV_ID}".replace("_", "-")
 
 DATASET_NAME = f"ds_sent_{ENV_ID}".replace("-", "_")
 DATASET = {
     "display_name": DATASET_NAME,
-    "text_sentiment_dataset_metadata": {"sentiment_max": 5},
+    "metadata_schema_uri": schema.dataset.metadata.text,
+    "metadata": Value(string_value="sent-dataset"),
 }
 
-AUTOML_DATASET_BUCKET = 
f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/automl/text_sentiment.csv"
-IMPORT_INPUT_CONFIG = {"gcs_source": {"input_uris": [AUTOML_DATASET_BUCKET]}}
+DATA_CONFIG = [
+    {
+        "import_schema_uri": schema.dataset.ioformat.text.sentiment,
+        "gcs_source": {"uris": [AUTOML_DATASET_BUCKET]},
+    },
+]
 
 extract_object_id = CloudAutoMLHook.extract_object_id
 
@@ -84,43 +93,61 @@ with models.DAG(
     move_dataset_file = GCSSynchronizeBucketsOperator(
         task_id="move_dataset_to_bucket",
         source_bucket=RESOURCE_DATA_BUCKET,
-        source_object="automl/datasets/text",
+        source_object="vertex-ai/automl/datasets/text",
         destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
         destination_object="automl",
         recursive=True,
     )
 
-    create_dataset = AutoMLCreateDatasetOperator(
-        task_id="create_dataset", dataset=DATASET, location=GCP_AUTOML_LOCATION
+    create_sent_dataset = CreateDatasetOperator(
+        task_id="create_sent_dataset",
+        dataset=DATASET,
+        region=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
     )
+    sent_dataset_id = create_sent_dataset.output["dataset_id"]
 
-    dataset_id = cast(str, XComArg(create_dataset, key="dataset_id"))
-    MODEL["dataset_id"] = dataset_id
-
-    import_dataset = AutoMLImportDataOperator(
-        task_id="import_dataset",
-        dataset_id=dataset_id,
-        location=GCP_AUTOML_LOCATION,
-        input_config=IMPORT_INPUT_CONFIG,
+    import_sent_dataset = ImportDataOperator(
+        task_id="import_sent_data",
+        dataset_id=sent_dataset_id,
+        region=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+        import_configs=DATA_CONFIG,
     )
 
-    MODEL["dataset_id"] = dataset_id
-
-    create_model = AutoMLTrainModelOperator(task_id="create_model", 
model=MODEL, location=GCP_AUTOML_LOCATION)
-    model_id = cast(str, XComArg(create_model, key="model_id"))
+    # [START howto_operator_automl_create_model]
+    create_sent_training_job = CreateAutoMLTextTrainingJobOperator(
+        task_id="create_sent_training_job",
+        display_name=TEXT_SENT_DISPLAY_NAME,
+        prediction_type="sentiment",
+        multi_label=False,
+        dataset_id=sent_dataset_id,
+        model_display_name=MODEL_NAME,
+        training_fraction_split=0.7,
+        validation_fraction_split=0.2,
+        test_fraction_split=0.1,
+        sentiment_max=5,
+        sync=True,
+        region=GCP_AUTOML_LOCATION,
+        project_id=GCP_PROJECT_ID,
+    )
+    # [END howto_operator_automl_create_model]
+    model_id = cast(str, XComArg(create_sent_training_job, key="model_id"))
 
-    delete_model = AutoMLDeleteModelOperator(
-        task_id="delete_model",
-        model_id=model_id,
-        location=GCP_AUTOML_LOCATION,
+    delete_sent_training_job = DeleteAutoMLTrainingJobOperator(
+        task_id="delete_sent_training_job",
+        training_pipeline_id=create_sent_training_job.output["training_id"],
+        region=GCP_AUTOML_LOCATION,
         project_id=GCP_PROJECT_ID,
+        trigger_rule=TriggerRule.ALL_DONE,
     )
 
-    delete_dataset = AutoMLDeleteDatasetOperator(
-        task_id="delete_dataset",
-        dataset_id=dataset_id,
-        location=GCP_AUTOML_LOCATION,
+    delete_sent_dataset = DeleteDatasetOperator(
+        task_id="delete_sent_dataset",
+        dataset_id=sent_dataset_id,
+        region=GCP_AUTOML_LOCATION,
         project_id=GCP_PROJECT_ID,
+        trigger_rule=TriggerRule.ALL_DONE,
     )
 
     delete_bucket = GCSDeleteBucketOperator(
@@ -131,13 +158,13 @@ with models.DAG(
 
     (
         # TEST SETUP
-        [create_bucket >> move_dataset_file, create_dataset]
+        [create_bucket >> move_dataset_file, create_sent_dataset]
         # TEST BODY
-        >> import_dataset
-        >> create_model
+        >> import_sent_dataset
+        >> create_sent_training_job
         # TEST TEARDOWN
-        >> delete_model
-        >> delete_dataset
+        >> delete_sent_training_job
+        >> delete_sent_dataset
         >> delete_bucket
     )
 

Reply via email to