This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7717985775f Introduce the translation API v3 (advanced) models 
operators. (#44627)
7717985775f is described below

commit 7717985775fa42f5722ea1b624c09ac1e2fcfa75
Author: olegkachur-e <[email protected]>
AuthorDate: Fri Dec 6 14:16:37 2024 +0100

    Introduce the translation API v3 (advanced) models operators. (#44627)
    
    - TranslateCreateModelOperator
    - TranslateModelsListOperator
    - TranslateDeleteModelOperator
    
    More details on using AutoML translation: 
https://cloud.google.com/translate/docs/advanced/automl-beginner.
    
    Co-authored-by: Oleg Kachur <[email protected]>
---
 .../operators/cloud/translate.rst                  |  63 +++++
 docs/spelling_wordlist.txt                         |   1 +
 .../providers/google/cloud/hooks/translate.py      | 154 +++++++++++++
 .../providers/google/cloud/links/translate.py      |  63 +++++
 .../providers/google/cloud/operators/translate.py  | 255 +++++++++++++++++++++
 .../src/airflow/providers/google/provider.yaml     |   2 +
 .../tests/google/cloud/operators/test_translate.py | 156 +++++++++++++
 .../cloud/translate/example_translate_model.py     | 178 ++++++++++++++
 8 files changed, 872 insertions(+)

diff --git a/docs/apache-airflow-providers-google/operators/cloud/translate.rst 
b/docs/apache-airflow-providers-google/operators/cloud/translate.rst
index d56fac26dbe..5bda3d9085a 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/translate.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/translate.rst
@@ -184,6 +184,69 @@ Basic usage of the operator:
     :end-before: [END howto_operator_translate_automl_delete_dataset]
 
 
+.. _howto/operator:TranslateCreateModelOperator:
+
+TranslateCreateModelOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Create a native translation model using Cloud Translate API (Advanced V3).
+
+For parameter definition, take a look at
+:class:`~airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator`
+
+Using the operator
+""""""""""""""""""
+
+Basic usage of the operator:
+
+.. exampleinclude:: 
/../../providers/tests/system/google/cloud/translate/example_translate_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_translate_automl_create_model]
+    :end-before: [END howto_operator_translate_automl_create_model]
+
+
+.. _howto/operator:TranslateModelsListOperator:
+
+TranslateModelsListOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Get a list of native translation models using the Cloud Translate API (Advanced V3).
+
+For parameter definition, take a look at
+:class:`~airflow.providers.google.cloud.operators.translate.TranslateModelsListOperator`
+
+Using the operator
+""""""""""""""""""
+
+Basic usage of the operator:
+
+.. exampleinclude:: 
/../../providers/tests/system/google/cloud/translate/example_translate_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_translate_automl_list_models]
+    :end-before: [END howto_operator_translate_automl_list_models]
+
+
+.. _howto/operator:TranslateDeleteModelOperator:
+
+TranslateDeleteModelOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Delete a native translation model using Cloud Translate API (Advanced V3).
+
+For parameter definition, take a look at
+:class:`~airflow.providers.google.cloud.operators.translate.TranslateDeleteModelOperator`
+
+Using the operator
+""""""""""""""""""
+
+Basic usage of the operator:
+
+.. exampleinclude:: 
/../../providers/tests/system/google/cloud/translate/example_translate_model.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_translate_automl_delete_model]
+    :end-before: [END howto_operator_translate_automl_delete_model]
+
+
 More information
 """"""""""""""""""
 See:
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index e02e6cee380..69656b6d086 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -970,6 +970,7 @@ linux
 ListDatasetsPager
 ListGenerator
 ListInfoTypesResponse
+ListModelsPager
 ListSecretsPager
 Liveness
 liveness
diff --git a/providers/src/airflow/providers/google/cloud/hooks/translate.py 
b/providers/src/airflow/providers/google/cloud/hooks/translate.py
index 4901e4a9914..43e0c15774b 100644
--- a/providers/src/airflow/providers/google/cloud/hooks/translate.py
+++ b/providers/src/airflow/providers/google/cloud/hooks/translate.py
@@ -560,3 +560,157 @@ class TranslateHook(GoogleBaseHook):
             metadata=metadata,
         )
         return result
+
+    def create_model(
+        self,
+        dataset_id: str,
+        display_name: str,
+        project_id: str,
+        location: str,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Operation:
+        """
+        Create a native model by training on the provided translation dataset.
+
+        :param dataset_id: ID of dataset to be used for model training.
+        :param display_name: Display name of the model trained.
+            A-Z and a-z, underscores (_), and ASCII digits 0-9.
+        :param project_id: ID of the Google Cloud project where dataset is 
located. If not provided
+            default project_id is used.
+        :param location: The location of the project.
+        :param retry: A retry object used to retry requests. If `None` is 
specified, requests will not be
+            retried.
+        :param timeout: The amount of time, in seconds, to wait for the 
request to complete. Note that if
+            `retry` is specified, the timeout applies to each individual 
attempt.
+        :param metadata: Additional metadata that is provided to the method.
+
+        :return: `Operation` object with the model creation results, when 
finished.
+        """
+        client = self.get_client()
+        project_id = project_id or self.project_id
+        parent = f"projects/{project_id}/locations/{location}"
+        dataset = 
f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
+        result = client.create_model(
+            request={
+                "parent": parent,
+                "model": {
+                    "display_name": display_name,
+                    "dataset": dataset,
+                },
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
+    def get_model(
+        self,
+        model_id: str,
+        project_id: str,
+        location: str,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | _MethodDefault = DEFAULT,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> automl_translation.Model:
+        """
+        Retrieve the translation model for the given model_id.
+
+        :param model_id: ID of translation model to be retrieved.
+        :param project_id: ID of the Google Cloud project where dataset is 
located. If not provided
+            default project_id is used.
+        :param location: The location of the project.
+        :param retry: A retry object used to retry requests. If `None` is 
specified, requests will not be
+            retried.
+        :param timeout: The amount of time, in seconds, to wait for the 
request to complete. Note that if
+            `retry` is specified, the timeout applies to each individual 
attempt.
+        :param metadata: Additional metadata that is provided to the method.
+
+        :return: `automl_translation.Model` instance.
+        """
+        client = self.get_client()
+        name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+        return client.get_model(
+            request={"name": name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    def list_models(
+        self,
+        project_id: str,
+        location: str,
+        filter_str: str | None = None,
+        page_size: int | None = None,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | _MethodDefault = DEFAULT,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> pagers.ListModelsPager:
+        """
+        List translation models in a project.
+
+        :param project_id: ID of the Google Cloud project where models are 
located. If not provided
+            default project_id is used.
+        :param location: The location of the project.
+        :param filter_str: An optional expression for filtering the models 
that will
+            be returned. Supported filter: ``dataset_id=${dataset_id}``.
+        :param page_size: Optional custom page size value. The server can
+            return fewer results than requested.
+        :param retry: A retry object used to retry requests. If `None` is 
specified, requests will not be
+            retried.
+        :param timeout: The amount of time, in seconds, to wait for the 
request to complete. Note that if
+            `retry` is specified, the timeout applies to each individual 
attempt.
+        :param metadata: Additional metadata that is provided to the method.
+
+        :return: ``pagers.ListModelsPager`` instance, iterable object to 
retrieve the models list.
+        """
+        client = self.get_client()
+        parent = f"projects/{project_id}/locations/{location}"
+        result = client.list_models(
+            request={
+                "parent": parent,
+                "filter": filter_str,
+                "page_size": page_size,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
+
+    def delete_model(
+        self,
+        model_id: str,
+        project_id: str,
+        location: str,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Operation:
+        """
+        Delete the translation model and all of its contents.
+
+        :param model_id: ID of model to be deleted.
+        :param project_id: ID of the Google Cloud project where dataset is 
located. If not provided
+            default project_id is used.
+        :param location: The location of the project.
+        :param retry: A retry object used to retry requests. If `None` is 
specified, requests will not be
+            retried.
+        :param timeout: The amount of time, in seconds, to wait for the 
request to complete. Note that if
+            `retry` is specified, the timeout applies to each individual 
attempt.
+        :param metadata: Additional metadata that is provided to the method.
+
+        :return: `Operation` object with model deletion results, when 
finished.
+        """
+        client = self.get_client()
+        name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+        result = client.delete_model(
+            request={"name": name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
diff --git a/providers/src/airflow/providers/google/cloud/links/translate.py 
b/providers/src/airflow/providers/google/cloud/links/translate.py
index 0d1489ddcfc..55db2650838 100644
--- a/providers/src/airflow/providers/google/cloud/links/translate.py
+++ b/providers/src/airflow/providers/google/cloud/links/translate.py
@@ -50,6 +50,12 @@ TRANSLATION_NATIVE_DATASET_LINK = (
 )
 TRANSLATION_NATIVE_LIST_LINK = TRANSLATION_BASE_LINK + 
"/datasets?project={project_id}"
 
+TRANSLATION_NATIVE_MODEL_LINK = (
+    TRANSLATION_BASE_LINK
+    + 
"/locations/{location}/datasets/{dataset_id}/evaluate;modelId={model_id}?project={project_id}"
+)
+TRANSLATION_MODELS_LIST_LINK = TRANSLATION_BASE_LINK + 
"/models/list?project={project_id}"
+
 
 class TranslationLegacyDatasetLink(BaseGoogleLink):
     """
@@ -270,3 +276,60 @@ class TranslationDatasetsListLink(BaseGoogleLink):
                 "project_id": project_id,
             },
         )
+
+
+class TranslationModelLink(BaseGoogleLink):
+    """
+    Helper class for constructing Translation Model link.
+
+    Link for legacy and native models.
+    """
+
+    name = "Translation Model"
+    key = "translation_model"
+    format_str = TRANSLATION_NATIVE_MODEL_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+        dataset_id: str,
+        model_id: str,
+        project_id: str,
+    ):
+        task_instance.xcom_push(
+            context,
+            key=TranslationLegacyModelLink.key,
+            value={
+                "location": task_instance.location,
+                "dataset_id": dataset_id,
+                "model_id": model_id,
+                "project_id": project_id,
+            },
+        )
+
+
+class TranslationModelsListLink(BaseGoogleLink):
+    """
+    Helper class for constructing Translation Models List link.
+
+    Both legacy and native models are available under this link.
+    """
+
+    name = "Translation Models List"
+    key = "translation_models_list"
+    format_str = TRANSLATION_MODELS_LIST_LINK
+
+    @staticmethod
+    def persist(
+        context: Context,
+        task_instance,
+        project_id: str,
+    ):
+        task_instance.xcom_push(
+            context,
+            key=TranslationModelsListLink.key,
+            value={
+                "project_id": project_id,
+            },
+        )
diff --git 
a/providers/src/airflow/providers/google/cloud/operators/translate.py 
b/providers/src/airflow/providers/google/cloud/operators/translate.py
index 193433deaa7..4c04e9a7bc5 100644
--- a/providers/src/airflow/providers/google/cloud/operators/translate.py
+++ b/providers/src/airflow/providers/google/cloud/operators/translate.py
@@ -30,6 +30,8 @@ from airflow.providers.google.cloud.hooks.translate import 
CloudTranslateHook, T
 from airflow.providers.google.cloud.links.translate import (
     TranslateTextBatchLink,
     TranslationDatasetsListLink,
+    TranslationModelLink,
+    TranslationModelsListLink,
     TranslationNativeDatasetLink,
 )
 from airflow.providers.google.cloud.operators.cloud_base import 
GoogleCloudBaseOperator
@@ -723,3 +725,256 @@ class 
TranslateDeleteDatasetOperator(GoogleCloudBaseOperator):
         )
         hook.wait_for_operation_done(operation=operation, timeout=self.timeout)
         self.log.info("Dataset deletion complete!")
+
+
+class TranslateCreateModelOperator(GoogleCloudBaseOperator):
+    """
+    Creates a Google Cloud Translate model.
+
+    Creates a `native` translation model, using API V3.
+    For more information on how to use this operator, take a look at the guide:
+    :ref:`howto/operator:TranslateCreateModelOperator`.
+
+    :param dataset_id: The dataset id used for model training.
+    :param project_id: ID of the Google Cloud project where dataset is located.
+        If not provided default project_id is used.
+    :param location: The location of the project.
+    :param retry: Designation of what errors, if any, should be retried.
+    :param timeout: The timeout for this request.
+    :param metadata:  Strings which should be sent along with the request as 
metadata.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using 
short-term
+        credentials, or chained list of accounts required to get the 
access_token
+        of the last account in the list, which will be impersonated in the 
request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding 
identity, with first
+        account from the list granting this role to the originating account 
(templated).
+    """
+
+    template_fields: Sequence[str] = (
+        "dataset_id",
+        "location",
+        "project_id",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    operator_extra_links = (TranslationModelLink(),)
+
+    def __init__(
+        self,
+        *,
+        project_id: str = PROVIDE_PROJECT_ID,
+        location: str,
+        dataset_id: str,
+        display_name: str,
+        timeout: float | None = None,
+        retry: Retry | _MethodDefault = DEFAULT,
+        gcp_conn_id: str = "google_cloud_default",
+        metadata: Sequence[tuple[str, str]] = (),
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.dataset_id = dataset_id
+        self.display_name = display_name
+        self.metadata = metadata
+        self.timeout = timeout
+        self.retry = retry
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context) -> str:
+        hook = TranslateHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        self.log.info("Model creation started, dataset_id %s...", 
self.dataset_id)
+        try:
+            result_operation = hook.create_model(
+                dataset_id=self.dataset_id,
+                display_name=self.display_name,
+                location=self.location,
+                project_id=self.project_id,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            )
+        except GoogleAPICallError as e:
+            self.log.error("Error submitting create_model operation ")
+            raise AirflowException(e)
+
+        self.log.info("Training has started")
+        hook.wait_for_operation_done(operation=result_operation)
+        result = hook.wait_for_operation_result(operation=result_operation)
+        result = type(result).to_dict(result)
+        model_id = hook.extract_object_id(result)
+        self.xcom_push(context, key="model_id", value=model_id)
+        self.log.info("Model creation complete. The model_id: %s.", model_id)
+
+        project_id = self.project_id or hook.project_id
+        TranslationModelLink.persist(
+            context=context,
+            task_instance=self,
+            dataset_id=self.dataset_id,
+            model_id=model_id,
+            project_id=project_id,
+        )
+        return result
+
+
+class TranslateModelsListOperator(GoogleCloudBaseOperator):
+    """
+    Get a list of native Google Cloud Translation models in a project.
+
+    Get project's list of `native` translation models, using API V3.
+    For more information on how to use this operator, take a look at the guide:
+    :ref:`howto/operator:TranslateModelsListOperator`.
+
+    :param project_id: ID of the Google Cloud project where dataset is located.
+        If not provided default project_id is used.
+    :param location: The location of the project.
+    :param retry: Designation of what errors, if any, should be retried.
+    :param timeout: The timeout for this request.
+    :param metadata:  Strings which should be sent along with the request as 
metadata.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using 
short-term
+        credentials, or chained list of accounts required to get the 
access_token
+        of the last account in the list, which will be impersonated in the 
request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding 
identity, with first
+        account from the list granting this role to the originating account 
(templated).
+    """
+
+    template_fields: Sequence[str] = (
+        "location",
+        "project_id",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    operator_extra_links = (TranslationModelsListLink(),)
+
+    def __init__(
+        self,
+        *,
+        project_id: str = PROVIDE_PROJECT_ID,
+        location: str,
+        metadata: Sequence[tuple[str, str]] = (),
+        timeout: float | _MethodDefault = DEFAULT,
+        retry: Retry | _MethodDefault = DEFAULT,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.location = location
+        self.metadata = metadata
+        self.timeout = timeout
+        self.retry = retry
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        hook = TranslateHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+        project_id = self.project_id or hook.project_id
+        TranslationModelsListLink.persist(
+            context=context,
+            task_instance=self,
+            project_id=project_id,
+        )
+        self.log.info("Requesting models list")
+        results_pager = hook.list_models(
+            location=self.location,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        result_ids = []
+        for model_item in results_pager:
+            model_data = type(model_item).to_dict(model_item)
+            model_id = hook.extract_object_id(model_data)
+            result_ids.append(model_id)
+        self.log.info("Fetching the models list complete. Model id-s: %s", 
result_ids)
+        return result_ids
+
+
+class TranslateDeleteModelOperator(GoogleCloudBaseOperator):
+    """
+    Delete translation model and all of its contents.
+
+    Deletes the translation model and its data, using API V3.
+    For more information on how to use this operator, take a look at the guide:
+    :ref:`howto/operator:TranslateDeleteModelOperator`.
+
+    :param model_id: The model_id of target native model to be deleted.
+    :param location: The location of the project.
+    :param retry: Designation of what errors, if any, should be retried.
+    :param timeout: The timeout for this request.
+    :param metadata:  Strings which should be sent along with the request as 
metadata.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using 
short-term
+        credentials, or chained list of accounts required to get the 
access_token
+        of the last account in the list, which will be impersonated in the 
request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding 
identity, with first
+        account from the list granting this role to the originating account 
(templated).
+    """
+
+    template_fields: Sequence[str] = (
+        "model_id",
+        "location",
+        "project_id",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        *,
+        model_id: str,
+        location: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        metadata: Sequence[tuple[str, str]] = (),
+        timeout: float | None = None,
+        retry: Retry | _MethodDefault = DEFAULT,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.model_id = model_id
+        self.project_id = project_id
+        self.location = location
+        self.metadata = metadata
+        self.timeout = timeout
+        self.retry = retry
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        hook = TranslateHook(gcp_conn_id=self.gcp_conn_id, 
impersonation_chain=self.impersonation_chain)
+        self.log.info("Deleting the model %s...", self.model_id)
+        operation = hook.delete_model(
+            model_id=self.model_id,
+            location=self.location,
+            project_id=self.project_id,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        hook.wait_for_operation_done(operation=operation, timeout=self.timeout)
+        self.log.info("Model deletion complete!")
diff --git a/providers/src/airflow/providers/google/provider.yaml 
b/providers/src/airflow/providers/google/provider.yaml
index 96d2271f406..44eb99a43f6 100644
--- a/providers/src/airflow/providers/google/provider.yaml
+++ b/providers/src/airflow/providers/google/provider.yaml
@@ -1295,6 +1295,8 @@ extra-links:
   - airflow.providers.google.cloud.links.translate.TranslateTextBatchLink
   - airflow.providers.google.cloud.links.translate.TranslationNativeDatasetLink
   - airflow.providers.google.cloud.links.translate.TranslationDatasetsListLink
+  - airflow.providers.google.cloud.links.translate.TranslationModelLink
+  - airflow.providers.google.cloud.links.translate.TranslationModelsListLink
 
 
 secrets-backends:
diff --git a/providers/tests/google/cloud/operators/test_translate.py 
b/providers/tests/google/cloud/operators/test_translate.py
index 7af9803ea22..d1b6a9fa009 100644
--- a/providers/tests/google/cloud/operators/test_translate.py
+++ b/providers/tests/google/cloud/operators/test_translate.py
@@ -26,9 +26,12 @@ from airflow.providers.google.cloud.hooks.translate import 
TranslateHook
 from airflow.providers.google.cloud.operators.translate import (
     CloudTranslateTextOperator,
     TranslateCreateDatasetOperator,
+    TranslateCreateModelOperator,
     TranslateDatasetsListOperator,
     TranslateDeleteDatasetOperator,
+    TranslateDeleteModelOperator,
     TranslateImportDataOperator,
+    TranslateModelsListOperator,
     TranslateTextBatchOperator,
     TranslateTextOperator,
 )
@@ -39,6 +42,7 @@ GCP_CONN_ID = "google_cloud_default"
 IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
 PROJECT_ID = "test-project-id"
 DATASET_ID = "sample_ds_id"
+MODEL_ID = "sample_model_id"
 TIMEOUT_VALUE = 30
 
 
@@ -386,3 +390,155 @@ class TestTranslateDeleteData:
             metadata=(),
         )
         
wait_for_done.assert_called_once_with(operation=m_delete_method_result, 
timeout=TIMEOUT_VALUE)
+
+
+class TestTranslateModelCreate:
+    
@mock.patch("airflow.providers.google.cloud.links.translate.TranslationModelLink.persist")
+    
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator.xcom_push")
+    
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook")
+    def test_minimal_green_path(self, mock_hook, mock_xcom_push, 
mock_link_persist):
+        MODEL_DISPLAY_NAME = "model_display_name_01"
+        MODEL_CREATION_RESULT_SAMPLE = {
+            "display_name": MODEL_DISPLAY_NAME,
+            "name": 
f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}",
+            "dataset": 
f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DATASET_ID}",
+            "source_language_code": "",
+            "target_language_code": "",
+            "create_time": "2024-11-15T14:05:00Z",
+            "update_time": "2024-11-16T01:09:03Z",
+            "test_example_count": 1000,
+            "train_example_count": 115,
+            "validate_example_count": 140,
+        }
+        sample_operation = mock.MagicMock()
+        sample_operation.result.return_value = 
automl_translation.Model(MODEL_CREATION_RESULT_SAMPLE)
+
+        mock_hook.return_value.create_model.return_value = sample_operation
+        mock_hook.return_value.wait_for_operation_result.side_effect = lambda 
operation: operation.result()
+        mock_hook.return_value.extract_object_id = 
TranslateHook.extract_object_id
+        op = TranslateCreateModelOperator(
+            task_id="task_id",
+            display_name=MODEL_DISPLAY_NAME,
+            dataset_id=DATASET_ID,
+            project_id=PROJECT_ID,
+            location=LOCATION,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            timeout=TIMEOUT_VALUE,
+            retry=None,
+        )
+        context = mock.MagicMock()
+        result = op.execute(context=context)
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.create_model.assert_called_once_with(
+            display_name=MODEL_DISPLAY_NAME,
+            dataset_id=DATASET_ID,
+            project_id=PROJECT_ID,
+            location=LOCATION,
+            timeout=TIMEOUT_VALUE,
+            retry=None,
+            metadata=(),
+        )
+        mock_xcom_push.assert_called_once_with(context, key="model_id", 
value=MODEL_ID)
+        mock_link_persist.assert_called_once_with(
+            context=context,
+            task_instance=op,
+            model_id=MODEL_ID,
+            project_id=PROJECT_ID,
+            dataset_id=DATASET_ID,
+        )
+        assert result == MODEL_CREATION_RESULT_SAMPLE
+
+
+class TestTranslateListModels:
+    
@mock.patch("airflow.providers.google.cloud.links.translate.TranslationModelsListLink.persist")
+    
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook")
+    def test_minimal_green_path(self, mock_hook, mock_link_persist):
+        MODEL_ID_1 = "sample_model_1"
+        MODEL_ID_2 = "sample_model_2"
+        model_result_1 = automl_translation.Model(
+            dict(
+                display_name="model_1_display_name",
+                
name=f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID_1}",
+                
dataset=f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/ds_for_model_1",
+                source_language_code="en",
+                target_language_code="es",
+            )
+        )
+        model_result_2 = automl_translation.Model(
+            dict(
+                display_name="model_2_display_name",
+                
name=f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID_2}",
+                
dataset=f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/ds_for_model_2",
+                source_language_code="uk",
+                target_language_code="en",
+            )
+        )
+        mock_hook.return_value.list_models.return_value = [model_result_1, 
model_result_2]
+        mock_hook.return_value.extract_object_id = 
TranslateHook.extract_object_id
+
+        op = TranslateModelsListOperator(
+            task_id="task_id",
+            project_id=PROJECT_ID,
+            location=LOCATION,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            timeout=TIMEOUT_VALUE,
+            retry=DEFAULT,
+        )
+        context = mock.MagicMock()
+        result = op.execute(context=context)
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.list_models.assert_called_once_with(
+            project_id=PROJECT_ID,
+            location=LOCATION,
+            timeout=TIMEOUT_VALUE,
+            retry=DEFAULT,
+            metadata=(),
+        )
+        assert result == [MODEL_ID_1, MODEL_ID_2]
+        mock_link_persist.assert_called_once_with(
+            context=context,
+            task_instance=op,
+            project_id=PROJECT_ID,
+        )
+
+
+class TestTranslateDeleteModel:
+    # Checks TranslateDeleteModelOperator against a mocked TranslateHook: hook construction,
+    # the delete_model call arguments, and the wait on the returned operation.
+    
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook")
+    def test_minimal_green_path(self, mock_hook):
+        # Stand-in for whatever delete_model returns; expected later as the `operation` argument.
+        m_delete_method_result = mock.MagicMock()
+        mock_hook.return_value.delete_model.return_value = 
m_delete_method_result
+        wait_for_done = mock_hook.return_value.wait_for_operation_done
+
+        op = TranslateDeleteModelOperator(
+            task_id="task_id",
+            model_id=MODEL_ID,
+            project_id=PROJECT_ID,
+            location=LOCATION,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            timeout=TIMEOUT_VALUE,
+            retry=DEFAULT,
+        )
+        context = mock.MagicMock()
+        op.execute(context=context)
+        # The hook must be built exactly once, from the connection id and impersonation chain.
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        # The operator arguments must be forwarded to delete_model, with empty metadata.
+        mock_hook.return_value.delete_model.assert_called_once_with(
+            model_id=MODEL_ID,
+            project_id=PROJECT_ID,
+            location=LOCATION,
+            timeout=TIMEOUT_VALUE,
+            retry=DEFAULT,
+            metadata=(),
+        )
+        # The operator must wait on the object returned by delete_model, using the same timeout.
+        
wait_for_done.assert_called_once_with(operation=m_delete_method_result, 
timeout=TIMEOUT_VALUE)
diff --git 
a/providers/tests/system/google/cloud/translate/example_translate_model.py 
b/providers/tests/system/google/cloud/translate/example_translate_model.py
new file mode 100644
index 00000000000..3514668fa2d
--- /dev/null
+++ b/providers/tests/system/google/cloud/translate/example_translate_model.py
@@ -0,0 +1,178 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that creates a native AutoML translation model and translates text with it
+using the Google Cloud Translation API V3.
+"""
+
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.gcs import 
GCSCreateBucketOperator, GCSDeleteBucketOperator
+from airflow.providers.google.cloud.operators.translate import (
+    TranslateCreateDatasetOperator,
+    TranslateCreateModelOperator,
+    TranslateDeleteDatasetOperator,
+    TranslateDeleteModelOperator,
+    TranslateImportDataOperator,
+    TranslateModelsListOperator,
+    TranslateTextOperator,
+)
+from airflow.providers.google.cloud.transfers.gcs_to_gcs import 
GCSToGCSOperator
+from airflow.utils.trigger_rule import TriggerRule
+
+DAG_ID = "gcp_translate_automl_native_model"
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+REGION = "us-central1"
+# Bucket the sample TSV file is copied from.
+RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
+# Working bucket name: built from DAG_ID and ENV_ID, with underscores replaced by dashes.
+DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
+DATA_FILE_NAME = "import_en-es_short.tsv"
+# Object path of the sample file inside RESOURCE_DATA_BUCKET.
+RESOURCE_PATH = f"V3_translate/create_ds/import_data/{DATA_FILE_NAME}"
+# NOTE(review): COPY_DATA_PATH is not referenced anywhere else in this file - confirm it is still needed.
+COPY_DATA_PATH = 
f"gs://{RESOURCE_DATA_BUCKET}/V3_translate/create_ds/import_data/{DATA_FILE_NAME}"
+# Object path of the copied file inside the working bucket.
+DST_PATH = f"translate/import/{DATA_FILE_NAME}"
+# gs:// URI of the copied file; passed to the import task as input_uri.
+DATASET_DATA_PATH = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/{DST_PATH}"
+# Dataset definition passed to TranslateCreateDatasetOperator.
+# NOTE(review): the sample file name says "en-es" while the dataset declares es -> en - confirm the direction.
+DATASET = {
+    "display_name": f"op_ds_native{DAG_ID}_{ENV_ID}",
+    "source_language_code": "es",
+    "target_language_code": "en",
+}
+
+
+# Build the example DAG; all tasks are declared inside the context manager below.
+with DAG(
+    DAG_ID,
+    schedule="@once",  # Override to match your needs
+    start_date=datetime(2024, 11, 1),
+    catchup=False,
+    tags=[
+        "example",
+        "translate_model",
+    ],
+) as dag:
+    # Setup: create the working bucket in REGION.
+    create_bucket = GCSCreateBucketOperator(
+        task_id="create_bucket",
+        bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+        storage_class="REGIONAL",
+        location=REGION,
+    )
+    # Setup: copy the sample TSV from the resources bucket into the working bucket.
+    copy_dataset_source_tsv = GCSToGCSOperator(
+        task_id="copy_dataset_file",
+        source_bucket=RESOURCE_DATA_BUCKET,
+        source_object=RESOURCE_PATH,
+        destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
+        destination_object=DST_PATH,
+    )
+
+    # Setup: create the translation dataset described by DATASET.
+    create_dataset_op = TranslateCreateDatasetOperator(
+        task_id="translate_v3_ds_create",
+        dataset=DATASET,
+        project_id=PROJECT_ID,
+        location=REGION,
+    )
+
+    # Setup: import the copied TSV into the dataset; the dataset id comes from the create task's output.
+    import_ds_data_op = TranslateImportDataOperator(
+        task_id="translate_v3_ds_import_data",
+        dataset_id=create_dataset_op.output["dataset_id"],
+        input_config={
+            "input_files": [{"usage": "UNASSIGNED", "gcs_source": 
{"input_uri": DATASET_DATA_PATH}}]
+        },
+        project_id=PROJECT_ID,
+        location=REGION,
+    )
+
+    # The model display name is cut to 32 characters and dashes are replaced by underscores.
+    # [START howto_operator_translate_automl_create_model]
+    create_model = TranslateCreateModelOperator(
+        task_id="translate_v3_model_create",
+        display_name=f"native_model_{ENV_ID}"[:32].replace("-", "_"),
+        dataset_id=create_dataset_op.output["dataset_id"],
+        project_id=PROJECT_ID,
+        location=REGION,
+    )
+    # [END howto_operator_translate_automl_create_model]
+
+    # [START howto_operator_translate_automl_list_models]
+    list_models = TranslateModelsListOperator(
+        task_id="translate_v3_list_models",
+        project_id=PROJECT_ID,
+        location=REGION,
+    )
+    # [END howto_operator_translate_automl_list_models]
+
+    # Model id taken from the create task's output; reused by the translate and delete tasks below.
+    model_id = create_model.output["model_id"]
+
+    # Translate two Spanish sentences to English with the model created above.
+    translate_text_with_model = TranslateTextOperator(
+        task_id="translate_v3_op",
+        contents=["Hola!", "Puedes traerme una taza de café, por favor?"],
+        # AutoML model format
+        model=f"projects/{PROJECT_ID}/locations/{REGION}/models/{model_id}",
+        source_language_code="es",
+        target_language_code="en",
+    )
+
+    # [START howto_operator_translate_automl_delete_model]
+    delete_model = TranslateDeleteModelOperator(
+        task_id="translate_v3_automl_delete_model",
+        model_id=model_id,
+        project_id=PROJECT_ID,
+        location=REGION,
+    )
+    # [END howto_operator_translate_automl_delete_model]
+
+    # Teardown: delete the dataset created by create_dataset_op.
+    delete_ds_op = TranslateDeleteDatasetOperator(
+        task_id="translate_v3_ds_delete",
+        dataset_id=create_dataset_op.output["dataset_id"],
+        project_id=PROJECT_ID,
+        location=REGION,
+    )
+    # [END howto_operator_translate_automl_delete_dataset]
+    # NOTE(review): the END marker above has no matching START marker in this file -
+    # confirm whether it is a copy-paste leftover.
+
+    # Teardown: delete the working bucket; this task uses the ALL_DONE trigger rule.
+    delete_bucket = GCSDeleteBucketOperator(
+        task_id="delete_bucket",
+        bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    # Task ordering: setup, then the test body, then teardown, chained with >>.
+    # NOTE(review): the bracketed head of the chain is a one-element list wrapping
+    # `create_bucket >> copy_dataset_source_tsv` - confirm the brackets are intentional.
+    (
+        # TEST SETUP
+        [create_bucket >> copy_dataset_source_tsv]
+        >> create_dataset_op
+        >> import_ds_data_op
+        # TEST BODY
+        >> create_model
+        >> list_models
+        >> translate_text_with_model
+        >> delete_model
+        # TEST TEARDOWN
+        >> delete_ds_op
+        >> delete_bucket
+    )
+
+    from tests_common.test_utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+
+from tests_common.test_utils.system_tests import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: 
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)


Reply via email to