This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7717985775f Introduce the translation API v3 (advanced) models
operators. (#44627)
7717985775f is described below
commit 7717985775fa42f5722ea1b624c09ac1e2fcfa75
Author: olegkachur-e <[email protected]>
AuthorDate: Fri Dec 6 14:16:37 2024 +0100
Introduce the translation API v3 (advanced) models operators. (#44627)
- TranslateCreateModelOperator
- TranslateModelsListOperator
- TranslateDeleteModelOperator
More details on using AutoML translation:
https://cloud.google.com/translate/docs/advanced/automl-beginner.
Co-authored-by: Oleg Kachur <[email protected]>
---
.../operators/cloud/translate.rst | 63 +++++
docs/spelling_wordlist.txt | 1 +
.../providers/google/cloud/hooks/translate.py | 154 +++++++++++++
.../providers/google/cloud/links/translate.py | 63 +++++
.../providers/google/cloud/operators/translate.py | 255 +++++++++++++++++++++
.../src/airflow/providers/google/provider.yaml | 2 +
.../tests/google/cloud/operators/test_translate.py | 156 +++++++++++++
.../cloud/translate/example_translate_model.py | 178 ++++++++++++++
8 files changed, 872 insertions(+)
diff --git a/docs/apache-airflow-providers-google/operators/cloud/translate.rst
b/docs/apache-airflow-providers-google/operators/cloud/translate.rst
index d56fac26dbe..5bda3d9085a 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/translate.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/translate.rst
@@ -184,6 +184,69 @@ Basic usage of the operator:
:end-before: [END howto_operator_translate_automl_delete_dataset]
+.. _howto/operator:TranslateCreateModelOperator:
+
+TranslateCreateModelOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Create a native translation model using Cloud Translate API (Advanced V3).
+
+For parameter definition, take a look at
+:class:`~airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator`
+
+Using the operator
+""""""""""""""""""
+
+Basic usage of the operator:
+
+.. exampleinclude::
/../../providers/tests/system/google/cloud/translate/example_translate_model.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_translate_automl_create_model]
+ :end-before: [END howto_operator_translate_automl_create_model]
+
+
+.. _howto/operator:TranslateModelsListOperator:
+
+TranslateModelsListOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Get list of native translation models using Cloud Translate API (Advanced V3).
+
+For parameter definition, take a look at
+:class:`~airflow.providers.google.cloud.operators.translate.TranslateModelsListOperator`
+
+Using the operator
+""""""""""""""""""
+
+Basic usage of the operator:
+
+.. exampleinclude::
/../../providers/tests/system/google/cloud/translate/example_translate_model.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_translate_automl_list_models]
+ :end-before: [END howto_operator_translate_automl_list_models]
+
+
+.. _howto/operator:TranslateDeleteModelOperator:
+
+TranslateDeleteModelOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Delete a native translation model using Cloud Translate API (Advanced V3).
+
+For parameter definition, take a look at
+:class:`~airflow.providers.google.cloud.operators.translate.TranslateDeleteModelOperator`
+
+Using the operator
+""""""""""""""""""
+
+Basic usage of the operator:
+
+.. exampleinclude::
/../../providers/tests/system/google/cloud/translate/example_translate_model.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_translate_automl_delete_model]
+ :end-before: [END howto_operator_translate_automl_delete_model]
+
+
More information
""""""""""""""""""
See:
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index e02e6cee380..69656b6d086 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -970,6 +970,7 @@ linux
ListDatasetsPager
ListGenerator
ListInfoTypesResponse
+ListModelsPager
ListSecretsPager
Liveness
liveness
diff --git a/providers/src/airflow/providers/google/cloud/hooks/translate.py
b/providers/src/airflow/providers/google/cloud/hooks/translate.py
index 4901e4a9914..43e0c15774b 100644
--- a/providers/src/airflow/providers/google/cloud/hooks/translate.py
+++ b/providers/src/airflow/providers/google/cloud/hooks/translate.py
@@ -560,3 +560,157 @@ class TranslateHook(GoogleBaseHook):
metadata=metadata,
)
return result
+
+ def create_model(
+ self,
+ dataset_id: str,
+ display_name: str,
+ project_id: str,
+ location: str,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Operation:
+ """
+        Create the native model by training on the provided translation dataset.
+
+ :param dataset_id: ID of dataset to be used for model training.
+ :param display_name: Display name of the model trained.
+ A-Z and a-z, underscores (_), and ASCII digits 0-9.
+ :param project_id: ID of the Google Cloud project where dataset is
located. If not provided
+ default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is
specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the
request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual
attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+ :return: `Operation` object with the model creation results, when
finished.
+ """
+ client = self.get_client()
+ project_id = project_id or self.project_id
+ parent = f"projects/{project_id}/locations/{location}"
+ dataset =
f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
+ result = client.create_model(
+ request={
+ "parent": parent,
+ "model": {
+ "display_name": display_name,
+ "dataset": dataset,
+ },
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ def get_model(
+ self,
+ model_id: str,
+ project_id: str,
+ location: str,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | _MethodDefault = DEFAULT,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> automl_translation.Model:
+ """
+        Retrieve the model for the given model_id.
+
+ :param model_id: ID of translation model to be retrieved.
+        :param project_id: ID of the Google Cloud project where model is
located. If not provided
+ default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is
specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the
request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual
attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+ :return: `automl_translation.Model` instance.
+ """
+ client = self.get_client()
+ name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+ return client.get_model(
+ request={"name": name},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ def list_models(
+ self,
+ project_id: str,
+ location: str,
+ filter_str: str | None = None,
+ page_size: int | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | _MethodDefault = DEFAULT,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> pagers.ListModelsPager:
+ """
+ List translation models in a project.
+
+ :param project_id: ID of the Google Cloud project where models are
located. If not provided
+ default project_id is used.
+ :param location: The location of the project.
+ :param filter_str: An optional expression for filtering the models
that will
+ be returned. Supported filter: ``dataset_id=${dataset_id}``.
+ :param page_size: Optional custom page size value. The server can
+ return fewer results than requested.
+ :param retry: A retry object used to retry requests. If `None` is
specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the
request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual
attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: ``pagers.ListModelsPager`` instance, iterable object to
retrieve the models list.
+ """
+ client = self.get_client()
+ parent = f"projects/{project_id}/locations/{location}"
+ result = client.list_models(
+ request={
+ "parent": parent,
+ "filter": filter_str,
+ "page_size": page_size,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ def delete_model(
+ self,
+ model_id: str,
+ project_id: str,
+ location: str,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Operation:
+ """
+ Delete the translation model and all of its contents.
+
+ :param model_id: ID of model to be deleted.
+        :param project_id: ID of the Google Cloud project where model is
located. If not provided
+ default project_id is used.
+ :param location: The location of the project.
+ :param retry: A retry object used to retry requests. If `None` is
specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the
request to complete. Note that if
+ `retry` is specified, the timeout applies to each individual
attempt.
+ :param metadata: Additional metadata that is provided to the method.
+
+        :return: `Operation` object with model deletion results, when
finished.
+ """
+ client = self.get_client()
+ name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+ result = client.delete_model(
+ request={"name": name},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
diff --git a/providers/src/airflow/providers/google/cloud/links/translate.py
b/providers/src/airflow/providers/google/cloud/links/translate.py
index 0d1489ddcfc..55db2650838 100644
--- a/providers/src/airflow/providers/google/cloud/links/translate.py
+++ b/providers/src/airflow/providers/google/cloud/links/translate.py
@@ -50,6 +50,12 @@ TRANSLATION_NATIVE_DATASET_LINK = (
)
TRANSLATION_NATIVE_LIST_LINK = TRANSLATION_BASE_LINK +
"/datasets?project={project_id}"
+TRANSLATION_NATIVE_MODEL_LINK = (
+ TRANSLATION_BASE_LINK
+ +
"/locations/{location}/datasets/{dataset_id}/evaluate;modelId={model_id}?project={project_id}"
+)
+TRANSLATION_MODELS_LIST_LINK = TRANSLATION_BASE_LINK +
"/models/list?project={project_id}"
+
class TranslationLegacyDatasetLink(BaseGoogleLink):
"""
@@ -270,3 +276,60 @@ class TranslationDatasetsListLink(BaseGoogleLink):
"project_id": project_id,
},
)
+
+
+class TranslationModelLink(BaseGoogleLink):
+ """
+ Helper class for constructing Translation Model link.
+
+ Link for legacy and native models.
+ """
+
+ name = "Translation Model"
+ key = "translation_model"
+ format_str = TRANSLATION_NATIVE_MODEL_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ task_instance,
+ dataset_id: str,
+ model_id: str,
+ project_id: str,
+ ):
+ task_instance.xcom_push(
+ context,
+            key=TranslationModelLink.key,
+ value={
+ "location": task_instance.location,
+ "dataset_id": dataset_id,
+ "model_id": model_id,
+ "project_id": project_id,
+ },
+ )
+
+
+class TranslationModelsListLink(BaseGoogleLink):
+ """
+ Helper class for constructing Translation Models List link.
+
+ Both legacy and native models are available under this link.
+ """
+
+ name = "Translation Models List"
+ key = "translation_models_list"
+ format_str = TRANSLATION_MODELS_LIST_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ task_instance,
+ project_id: str,
+ ):
+ task_instance.xcom_push(
+ context,
+ key=TranslationModelsListLink.key,
+ value={
+ "project_id": project_id,
+ },
+ )
diff --git
a/providers/src/airflow/providers/google/cloud/operators/translate.py
b/providers/src/airflow/providers/google/cloud/operators/translate.py
index 193433deaa7..4c04e9a7bc5 100644
--- a/providers/src/airflow/providers/google/cloud/operators/translate.py
+++ b/providers/src/airflow/providers/google/cloud/operators/translate.py
@@ -30,6 +30,8 @@ from airflow.providers.google.cloud.hooks.translate import
CloudTranslateHook, T
from airflow.providers.google.cloud.links.translate import (
TranslateTextBatchLink,
TranslationDatasetsListLink,
+ TranslationModelLink,
+ TranslationModelsListLink,
TranslationNativeDatasetLink,
)
from airflow.providers.google.cloud.operators.cloud_base import
GoogleCloudBaseOperator
@@ -723,3 +725,256 @@ class
TranslateDeleteDatasetOperator(GoogleCloudBaseOperator):
)
hook.wait_for_operation_done(operation=operation, timeout=self.timeout)
self.log.info("Dataset deletion complete!")
+
+
+class TranslateCreateModelOperator(GoogleCloudBaseOperator):
+ """
+ Creates a Google Cloud Translate model.
+
+ Creates a `native` translation model, using API V3.
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:TranslateCreateModelOperator`.
+
+    :param dataset_id: The dataset id used for model training.
+    :param display_name: The display name of the model trained.
+ :param project_id: ID of the Google Cloud project where dataset is located.
+ If not provided default project_id is used.
+ :param location: The location of the project.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as
metadata.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "dataset_id",
+ "location",
+ "project_id",
+ "gcp_conn_id",
+ "impersonation_chain",
+ )
+
+ operator_extra_links = (TranslationModelLink(),)
+
+ def __init__(
+ self,
+ *,
+ project_id: str = PROVIDE_PROJECT_ID,
+ location: str,
+ dataset_id: str,
+ display_name: str,
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ metadata: Sequence[tuple[str, str]] = (),
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.project_id = project_id
+ self.location = location
+ self.dataset_id = dataset_id
+ self.display_name = display_name
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context) -> str:
+ hook = TranslateHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ self.log.info("Model creation started, dataset_id %s...",
self.dataset_id)
+ try:
+ result_operation = hook.create_model(
+ dataset_id=self.dataset_id,
+ display_name=self.display_name,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ except GoogleAPICallError as e:
+ self.log.error("Error submitting create_model operation ")
+ raise AirflowException(e)
+
+ self.log.info("Training has started")
+ hook.wait_for_operation_done(operation=result_operation)
+ result = hook.wait_for_operation_result(operation=result_operation)
+ result = type(result).to_dict(result)
+ model_id = hook.extract_object_id(result)
+ self.xcom_push(context, key="model_id", value=model_id)
+ self.log.info("Model creation complete. The model_id: %s.", model_id)
+
+ project_id = self.project_id or hook.project_id
+ TranslationModelLink.persist(
+ context=context,
+ task_instance=self,
+ dataset_id=self.dataset_id,
+ model_id=model_id,
+ project_id=project_id,
+ )
+ return result
+
+
+class TranslateModelsListOperator(GoogleCloudBaseOperator):
+ """
+ Get a list of native Google Cloud Translation models in a project.
+
+ Get project's list of `native` translation models, using API V3.
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:TranslateModelsListOperator`.
+
+    :param project_id: ID of the Google Cloud project where models are located.
+ If not provided default project_id is used.
+ :param location: The location of the project.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as
metadata.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "location",
+ "project_id",
+ "gcp_conn_id",
+ "impersonation_chain",
+ )
+
+ operator_extra_links = (TranslationModelsListLink(),)
+
+ def __init__(
+ self,
+ *,
+ project_id: str = PROVIDE_PROJECT_ID,
+ location: str,
+ metadata: Sequence[tuple[str, str]] = (),
+ timeout: float | _MethodDefault = DEFAULT,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.project_id = project_id
+ self.location = location
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = TranslateHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ project_id = self.project_id or hook.project_id
+ TranslationModelsListLink.persist(
+ context=context,
+ task_instance=self,
+ project_id=project_id,
+ )
+ self.log.info("Requesting models list")
+ results_pager = hook.list_models(
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ result_ids = []
+ for model_item in results_pager:
+ model_data = type(model_item).to_dict(model_item)
+ model_id = hook.extract_object_id(model_data)
+ result_ids.append(model_id)
+ self.log.info("Fetching the models list complete. Model id-s: %s",
result_ids)
+ return result_ids
+
+
+class TranslateDeleteModelOperator(GoogleCloudBaseOperator):
+ """
+ Delete translation model and all of its contents.
+
+    Deletes the translation model and its data, using API V3.
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:TranslateDeleteModelOperator`.
+
+    :param model_id: The model_id of target native model to be deleted.
+    :param project_id: ID of the Google Cloud project where model is located.
+        If not provided default project_id is used.
+ :param location: The location of the project.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as
metadata.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "model_id",
+ "location",
+ "project_id",
+ "gcp_conn_id",
+ "impersonation_chain",
+ )
+
+ def __init__(
+ self,
+ *,
+ model_id: str,
+ location: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ metadata: Sequence[tuple[str, str]] = (),
+ timeout: float | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.model_id = model_id
+ self.project_id = project_id
+ self.location = location
+ self.metadata = metadata
+ self.timeout = timeout
+ self.retry = retry
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+
+ def execute(self, context: Context):
+ hook = TranslateHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
+ self.log.info("Deleting the model %s...", self.model_id)
+ operation = hook.delete_model(
+ model_id=self.model_id,
+ location=self.location,
+ project_id=self.project_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ hook.wait_for_operation_done(operation=operation, timeout=self.timeout)
+ self.log.info("Model deletion complete!")
diff --git a/providers/src/airflow/providers/google/provider.yaml
b/providers/src/airflow/providers/google/provider.yaml
index 96d2271f406..44eb99a43f6 100644
--- a/providers/src/airflow/providers/google/provider.yaml
+++ b/providers/src/airflow/providers/google/provider.yaml
@@ -1295,6 +1295,8 @@ extra-links:
- airflow.providers.google.cloud.links.translate.TranslateTextBatchLink
- airflow.providers.google.cloud.links.translate.TranslationNativeDatasetLink
- airflow.providers.google.cloud.links.translate.TranslationDatasetsListLink
+ - airflow.providers.google.cloud.links.translate.TranslationModelLink
+ - airflow.providers.google.cloud.links.translate.TranslationModelsListLink
secrets-backends:
diff --git a/providers/tests/google/cloud/operators/test_translate.py
b/providers/tests/google/cloud/operators/test_translate.py
index 7af9803ea22..d1b6a9fa009 100644
--- a/providers/tests/google/cloud/operators/test_translate.py
+++ b/providers/tests/google/cloud/operators/test_translate.py
@@ -26,9 +26,12 @@ from airflow.providers.google.cloud.hooks.translate import
TranslateHook
from airflow.providers.google.cloud.operators.translate import (
CloudTranslateTextOperator,
TranslateCreateDatasetOperator,
+ TranslateCreateModelOperator,
TranslateDatasetsListOperator,
TranslateDeleteDatasetOperator,
+ TranslateDeleteModelOperator,
TranslateImportDataOperator,
+ TranslateModelsListOperator,
TranslateTextBatchOperator,
TranslateTextOperator,
)
@@ -39,6 +42,7 @@ GCP_CONN_ID = "google_cloud_default"
IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
PROJECT_ID = "test-project-id"
DATASET_ID = "sample_ds_id"
+MODEL_ID = "sample_model_id"
TIMEOUT_VALUE = 30
@@ -386,3 +390,155 @@ class TestTranslateDeleteData:
metadata=(),
)
wait_for_done.assert_called_once_with(operation=m_delete_method_result,
timeout=TIMEOUT_VALUE)
+
+
+class TestTranslateModelCreate:
+
@mock.patch("airflow.providers.google.cloud.links.translate.TranslationModelLink.persist")
+
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator.xcom_push")
+
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook")
+ def test_minimal_green_path(self, mock_hook, mock_xcom_push,
mock_link_persist):
+ MODEL_DISPLAY_NAME = "model_display_name_01"
+ MODEL_CREATION_RESULT_SAMPLE = {
+ "display_name": MODEL_DISPLAY_NAME,
+ "name":
f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}",
+ "dataset":
f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DATASET_ID}",
+ "source_language_code": "",
+ "target_language_code": "",
+ "create_time": "2024-11-15T14:05:00Z",
+ "update_time": "2024-11-16T01:09:03Z",
+ "test_example_count": 1000,
+ "train_example_count": 115,
+ "validate_example_count": 140,
+ }
+ sample_operation = mock.MagicMock()
+ sample_operation.result.return_value =
automl_translation.Model(MODEL_CREATION_RESULT_SAMPLE)
+
+ mock_hook.return_value.create_model.return_value = sample_operation
+ mock_hook.return_value.wait_for_operation_result.side_effect = lambda
operation: operation.result()
+ mock_hook.return_value.extract_object_id =
TranslateHook.extract_object_id
+ op = TranslateCreateModelOperator(
+ task_id="task_id",
+ display_name=MODEL_DISPLAY_NAME,
+ dataset_id=DATASET_ID,
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ timeout=TIMEOUT_VALUE,
+ retry=None,
+ )
+ context = mock.MagicMock()
+ result = op.execute(context=context)
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ mock_hook.return_value.create_model.assert_called_once_with(
+ display_name=MODEL_DISPLAY_NAME,
+ dataset_id=DATASET_ID,
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ timeout=TIMEOUT_VALUE,
+ retry=None,
+ metadata=(),
+ )
+ mock_xcom_push.assert_called_once_with(context, key="model_id",
value=MODEL_ID)
+ mock_link_persist.assert_called_once_with(
+ context=context,
+ task_instance=op,
+ model_id=MODEL_ID,
+ project_id=PROJECT_ID,
+ dataset_id=DATASET_ID,
+ )
+ assert result == MODEL_CREATION_RESULT_SAMPLE
+
+
+class TestTranslateListModels:
+
@mock.patch("airflow.providers.google.cloud.links.translate.TranslationModelsListLink.persist")
+
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook")
+ def test_minimal_green_path(self, mock_hook, mock_link_persist):
+ MODEL_ID_1 = "sample_model_1"
+ MODEL_ID_2 = "sample_model_2"
+ model_result_1 = automl_translation.Model(
+ dict(
+ display_name="model_1_display_name",
+
name=f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID_1}",
+
dataset=f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/ds_for_model_1",
+ source_language_code="en",
+ target_language_code="es",
+ )
+ )
+ model_result_2 = automl_translation.Model(
+ dict(
+ display_name="model_2_display_name",
+
name=f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID_2}",
+
dataset=f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/ds_for_model_2",
+ source_language_code="uk",
+ target_language_code="en",
+ )
+ )
+ mock_hook.return_value.list_models.return_value = [model_result_1,
model_result_2]
+ mock_hook.return_value.extract_object_id =
TranslateHook.extract_object_id
+
+ op = TranslateModelsListOperator(
+ task_id="task_id",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ timeout=TIMEOUT_VALUE,
+ retry=DEFAULT,
+ )
+ context = mock.MagicMock()
+ result = op.execute(context=context)
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ mock_hook.return_value.list_models.assert_called_once_with(
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ timeout=TIMEOUT_VALUE,
+ retry=DEFAULT,
+ metadata=(),
+ )
+ assert result == [MODEL_ID_1, MODEL_ID_2]
+ mock_link_persist.assert_called_once_with(
+ context=context,
+ task_instance=op,
+ project_id=PROJECT_ID,
+ )
+
+
+class TestTranslateDeleteModel:
+
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook")
+ def test_minimal_green_path(self, mock_hook):
+ m_delete_method_result = mock.MagicMock()
+ mock_hook.return_value.delete_model.return_value =
m_delete_method_result
+ wait_for_done = mock_hook.return_value.wait_for_operation_done
+
+ op = TranslateDeleteModelOperator(
+ task_id="task_id",
+ model_id=MODEL_ID,
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ timeout=TIMEOUT_VALUE,
+ retry=DEFAULT,
+ )
+ context = mock.MagicMock()
+ op.execute(context=context)
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ mock_hook.return_value.delete_model.assert_called_once_with(
+ model_id=MODEL_ID,
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ timeout=TIMEOUT_VALUE,
+ retry=DEFAULT,
+ metadata=(),
+ )
+
wait_for_done.assert_called_once_with(operation=m_delete_method_result,
timeout=TIMEOUT_VALUE)
diff --git
a/providers/tests/system/google/cloud/translate/example_translate_model.py
b/providers/tests/system/google/cloud/translate/example_translate_model.py
new file mode 100644
index 00000000000..3514668fa2d
--- /dev/null
+++ b/providers/tests/system/google/cloud/translate/example_translate_model.py
@@ -0,0 +1,178 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that creates, lists and deletes native translation models
+in Google Cloud Translate using the V3
API version
+service in the Google Cloud.
+"""
+
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.gcs import
GCSCreateBucketOperator, GCSDeleteBucketOperator
+from airflow.providers.google.cloud.operators.translate import (
+ TranslateCreateDatasetOperator,
+ TranslateCreateModelOperator,
+ TranslateDeleteDatasetOperator,
+ TranslateDeleteModelOperator,
+ TranslateImportDataOperator,
+ TranslateModelsListOperator,
+ TranslateTextOperator,
+)
+from airflow.providers.google.cloud.transfers.gcs_to_gcs import
GCSToGCSOperator
+from airflow.utils.trigger_rule import TriggerRule
+
+DAG_ID = "gcp_translate_automl_native_model"
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+REGION = "us-central1"
+RESOURCE_DATA_BUCKET = "airflow-system-tests-resources"
+DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-")
+DATA_FILE_NAME = "import_en-es_short.tsv"
+RESOURCE_PATH = f"V3_translate/create_ds/import_data/{DATA_FILE_NAME}"
+COPY_DATA_PATH =
f"gs://{RESOURCE_DATA_BUCKET}/V3_translate/create_ds/import_data/{DATA_FILE_NAME}"
+DST_PATH = f"translate/import/{DATA_FILE_NAME}"
+DATASET_DATA_PATH = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/{DST_PATH}"
+DATASET = {
+ "display_name": f"op_ds_native{DAG_ID}_{ENV_ID}",
+ "source_language_code": "es",
+ "target_language_code": "en",
+}
+
+
+with DAG(
+ DAG_ID,
+ schedule="@once", # Override to match your needs
+ start_date=datetime(2024, 11, 1),
+ catchup=False,
+ tags=[
+ "example",
+ "translate_model",
+ ],
+) as dag:
+ create_bucket = GCSCreateBucketOperator(
+ task_id="create_bucket",
+ bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+ storage_class="REGIONAL",
+ location=REGION,
+ )
+ copy_dataset_source_tsv = GCSToGCSOperator(
+ task_id="copy_dataset_file",
+ source_bucket=RESOURCE_DATA_BUCKET,
+ source_object=RESOURCE_PATH,
+ destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME,
+ destination_object=DST_PATH,
+ )
+
+ create_dataset_op = TranslateCreateDatasetOperator(
+ task_id="translate_v3_ds_create",
+ dataset=DATASET,
+ project_id=PROJECT_ID,
+ location=REGION,
+ )
+
+ import_ds_data_op = TranslateImportDataOperator(
+ task_id="translate_v3_ds_import_data",
+ dataset_id=create_dataset_op.output["dataset_id"],
+ input_config={
+ "input_files": [{"usage": "UNASSIGNED", "gcs_source":
{"input_uri": DATASET_DATA_PATH}}]
+ },
+ project_id=PROJECT_ID,
+ location=REGION,
+ )
+
+ # [START howto_operator_translate_automl_create_model]
+ create_model = TranslateCreateModelOperator(
+ task_id="translate_v3_model_create",
+ display_name=f"native_model_{ENV_ID}"[:32].replace("-", "_"),
+ dataset_id=create_dataset_op.output["dataset_id"],
+ project_id=PROJECT_ID,
+ location=REGION,
+ )
+ # [END howto_operator_translate_automl_create_model]
+
+ # [START howto_operator_translate_automl_list_models]
+ list_models = TranslateModelsListOperator(
+ task_id="translate_v3_list_models",
+ project_id=PROJECT_ID,
+ location=REGION,
+ )
+ # [END howto_operator_translate_automl_list_models]
+
+ model_id = create_model.output["model_id"]
+
+ translate_text_with_model = TranslateTextOperator(
+ task_id="translate_v3_op",
+ contents=["Hola!", "Puedes traerme una taza de café, por favor?"],
+ # AutoML model format
+ model=f"projects/{PROJECT_ID}/locations/{REGION}/models/{model_id}",
+ source_language_code="es",
+ target_language_code="en",
+ )
+
+ # [START howto_operator_translate_automl_delete_model]
+ delete_model = TranslateDeleteModelOperator(
+ task_id="translate_v3_automl_delete_model",
+ model_id=model_id,
+ project_id=PROJECT_ID,
+ location=REGION,
+ )
+ # [END howto_operator_translate_automl_delete_model]
+
+ delete_ds_op = TranslateDeleteDatasetOperator(
+ task_id="translate_v3_ds_delete",
+ dataset_id=create_dataset_op.output["dataset_id"],
+ project_id=PROJECT_ID,
+ location=REGION,
+ )
+ # [END howto_operator_translate_automl_delete_dataset]
+
+ delete_bucket = GCSDeleteBucketOperator(
+ task_id="delete_bucket",
+ bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ (
+ # TEST SETUP
+ [create_bucket >> copy_dataset_source_tsv]
+ >> create_dataset_op
+ >> import_ds_data_op
+ # TEST BODY
+ >> create_model
+ >> list_models
+ >> translate_text_with_model
+ >> delete_model
+ # TEST TEARDOWN
+ >> delete_ds_op
+ >> delete_bucket
+ )
+
+ from tests_common.test_utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests_common.test_utils.system_tests import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)