This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 88d5725c6b Add AutoML templating tests (#38170)
88d5725c6b is described below

commit 88d5725c6b24bc7d8a1c8201912d33458fda985f
Author: Shahar Epstein <[email protected]>
AuthorDate: Fri Mar 15 20:58:08 2024 +0200

    Add AutoML templating tests (#38170)
---
 .../google/cloud/operators/test_automl.py          | 286 +++++++++++++++++++++
 1 file changed, 286 insertions(+)

diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py
index 2c9872f445..4f00f76a2d 100644
--- a/tests/providers/google/cloud/operators/test_automl.py
+++ b/tests/providers/google/cloud/operators/test_automl.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 import copy
 from unittest import mock
 
+import pytest
 from google.api_core.gapic_v1.method import DEFAULT
 from google.cloud.automl_v1beta1 import BatchPredictResult, Dataset, Model, PredictResponse
 
@@ -39,6 +40,7 @@ from airflow.providers.google.cloud.operators.automl import (
     AutoMLTablesUpdateDatasetOperator,
     AutoMLTrainModelOperator,
 )
+from airflow.utils import timezone
 
 CREDENTIALS = "test-creds"
 TASK_ID = "test-automl-hook"
@@ -88,6 +90,25 @@ class TestAutoMLTrainModelOperator:
             metadata=(),
         )
 
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLTrainModelOperator,
+            # Templated fields
+            model="{{ 'model' }}",
+            location="{{ 'location' }}",
+            impersonation_chain="{{ 'impersonation_chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLTrainModelOperator = ti.task
+        assert task.model == "model"
+        assert task.location == "location"
+        assert task.impersonation_chain == "impersonation_chain"
+
 
 class TestAutoMLBatchPredictOperator:
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -118,6 +139,31 @@ class TestAutoMLBatchPredictOperator:
             timeout=None,
         )
 
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLBatchPredictOperator,
+            # Templated fields
+            model_id="{{ 'model' }}",
+            input_config="{{ 'input-config' }}",
+            output_config="{{ 'output-config' }}",
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLBatchPredictOperator = ti.task
+        assert task.model_id == "model"
+        assert task.input_config == "input-config"
+        assert task.output_config == "output-config"
+        assert task.location == "location"
+        assert task.project_id == "project-id"
+        assert task.impersonation_chain == "impersonation-chain"
+
 
 class TestAutoMLPredictOperator:
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -144,6 +190,28 @@ class TestAutoMLPredictOperator:
             timeout=None,
         )
 
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLPredictOperator,
+            # Templated fields
+            model_id="{{ 'model-id' }}",
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+            payload={},
+        )
+        ti.render_templates()
+        task: AutoMLPredictOperator = ti.task
+        assert task.model_id == "model-id"
+        assert task.project_id == "project-id"
+        assert task.location == "location"
+        assert task.impersonation_chain == "impersonation-chain"
+
 
 class TestAutoMLCreateImportOperator:
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -167,6 +235,27 @@ class TestAutoMLCreateImportOperator:
             timeout=None,
         )
 
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLCreateDatasetOperator,
+            # Templated fields
+            dataset="{{ 'dataset' }}",
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLCreateDatasetOperator = ti.task
+        assert task.dataset == "dataset"
+        assert task.project_id == "project-id"
+        assert task.location == "location"
+        assert task.impersonation_chain == "impersonation-chain"
+
 
 class TestAutoMLListColumnsSpecsOperator:
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -199,6 +288,33 @@ class TestAutoMLListColumnsSpecsOperator:
             timeout=None,
         )
 
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLTablesListColumnSpecsOperator,
+            # Templated fields
+            dataset_id="{{ 'dataset-id' }}",
+            table_spec_id="{{ 'table-spec-id' }}",
+            field_mask="{{ 'field-mask' }}",
+            filter_="{{ 'filter-' }}",
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLTablesListColumnSpecsOperator = ti.task
+        assert task.dataset_id == "dataset-id"
+        assert task.table_spec_id == "table-spec-id"
+        assert task.field_mask == "field-mask"
+        assert task.filter_ == "filter-"
+        assert task.location == "location"
+        assert task.project_id == "project-id"
+        assert task.impersonation_chain == "impersonation-chain"
+
 
 class TestAutoMLUpdateDatasetOperator:
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -223,6 +339,27 @@ class TestAutoMLUpdateDatasetOperator:
             update_mask=MASK,
         )
 
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLTablesUpdateDatasetOperator,
+            # Templated fields
+            dataset="{{ 'dataset' }}",
+            update_mask="{{ 'update-mask' }}",
+            location="{{ 'location' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLTablesUpdateDatasetOperator = ti.task
+        assert task.dataset == "dataset"
+        assert task.update_mask == "update-mask"
+        assert task.location == "location"
+        assert task.impersonation_chain == "impersonation-chain"
+
 
 class TestAutoMLGetModelOperator:
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -246,6 +383,27 @@ class TestAutoMLGetModelOperator:
             timeout=None,
         )
 
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLGetModelOperator,
+            # Templated fields
+            model_id="{{ 'model-id' }}",
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLGetModelOperator = ti.task
+        assert task.model_id == "model-id"
+        assert task.location == "location"
+        assert task.project_id == "project-id"
+        assert task.impersonation_chain == "impersonation-chain"
+
 
 class TestAutoMLDeleteModelOperator:
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -266,6 +424,27 @@ class TestAutoMLDeleteModelOperator:
             timeout=None,
         )
 
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLDeleteModelOperator,
+            # Templated fields
+            model_id="{{ 'model-id' }}",
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLDeleteModelOperator = ti.task
+        assert task.model_id == "model-id"
+        assert task.location == "location"
+        assert task.project_id == "project-id"
+        assert task.impersonation_chain == "impersonation-chain"
+
 
 class TestAutoMLDeployModelOperator:
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -289,6 +468,27 @@ class TestAutoMLDeployModelOperator:
             timeout=None,
         )
 
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLDeployModelOperator,
+            # Templated fields
+            model_id="{{ 'model-id' }}",
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLDeployModelOperator = ti.task
+        assert task.model_id == "model-id"
+        assert task.location == "location"
+        assert task.project_id == "project-id"
+        assert task.impersonation_chain == "impersonation-chain"
+
 
 class TestAutoMLDatasetImportOperator:
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -311,6 +511,29 @@ class TestAutoMLDatasetImportOperator:
             timeout=None,
         )
 
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLImportDataOperator,
+            # Templated fields
+            dataset_id="{{ 'dataset-id' }}",
+            input_config="{{ 'input-config' }}",
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLImportDataOperator = ti.task
+        assert task.dataset_id == "dataset-id"
+        assert task.input_config == "input-config"
+        assert task.location == "location"
+        assert task.project_id == "project-id"
+        assert task.impersonation_chain == "impersonation-chain"
+
 
 class TestAutoMLTablesListTableSpecsOperator:
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -338,6 +561,29 @@ class TestAutoMLTablesListTableSpecsOperator:
             timeout=None,
         )
 
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLTablesListTableSpecsOperator,
+            # Templated fields
+            dataset_id="{{ 'dataset-id' }}",
+            filter_="{{ 'filter-' }}",
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLTablesListTableSpecsOperator = ti.task
+        assert task.dataset_id == "dataset-id"
+        assert task.filter_ == "filter-"
+        assert task.location == "location"
+        assert task.project_id == "project-id"
+        assert task.impersonation_chain == "impersonation-chain"
+
 
 class TestAutoMLDatasetListOperator:
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -352,6 +598,25 @@ class TestAutoMLDatasetListOperator:
             timeout=None,
         )
 
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLListDatasetOperator,
+            # Templated fields
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLListDatasetOperator = ti.task
+        assert task.location == "location"
+        assert task.project_id == "project-id"
+        assert task.impersonation_chain == "impersonation-chain"
+
 
 class TestAutoMLDatasetDeleteOperator:
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -371,3 +636,24 @@ class TestAutoMLDatasetDeleteOperator:
             retry=DEFAULT,
             timeout=None,
         )
+
+    @pytest.mark.db_test
+    def test_templating(self, create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AutoMLDeleteDatasetOperator,
+            # Templated fields
+            dataset_id="{{ 'dataset-id' }}",
+            location="{{ 'location' }}",
+            project_id="{{ 'project-id' }}",
+            impersonation_chain="{{ 'impersonation-chain' }}",
+            # Other parameters
+            dag_id="test_template_body_templating_dag",
+            task_id="test_template_body_templating_task",
+            execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+        )
+        ti.render_templates()
+        task: AutoMLDeleteDatasetOperator = ti.task
+        assert task.dataset_id == "dataset-id"
+        assert task.location == "location"
+        assert task.project_id == "project-id"
+        assert task.impersonation_chain == "impersonation-chain"

Reply via email to