This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 88d5725c6b Add AutoML templating tests (#38170)
88d5725c6b is described below
commit 88d5725c6b24bc7d8a1c8201912d33458fda985f
Author: Shahar Epstein <[email protected]>
AuthorDate: Fri Mar 15 20:58:08 2024 +0200
Add AutoML templating tests (#38170)
---
.../google/cloud/operators/test_automl.py | 286 +++++++++++++++++++++
1 file changed, 286 insertions(+)
diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py
index 2c9872f445..4f00f76a2d 100644
--- a/tests/providers/google/cloud/operators/test_automl.py
+++ b/tests/providers/google/cloud/operators/test_automl.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import copy
from unittest import mock
+import pytest
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud.automl_v1beta1 import BatchPredictResult, Dataset, Model, PredictResponse
@@ -39,6 +40,7 @@ from airflow.providers.google.cloud.operators.automl import (
AutoMLTablesUpdateDatasetOperator,
AutoMLTrainModelOperator,
)
+from airflow.utils import timezone
CREDENTIALS = "test-creds"
TASK_ID = "test-automl-hook"
@@ -88,6 +90,25 @@ class TestAutoMLTrainModelOperator:
metadata=(),
)
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):  # each templated field is passed as a Jinja expression and must render to its literal
+ ti = create_task_instance_of_operator(
+ AutoMLTrainModelOperator,
+ # Templated fields
+ model="{{ 'model' }}",
+ location="{{ 'location' }}",
+ impersonation_chain="{{ 'impersonation_chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()  # rendered values are read back from ti.task below
+ task: AutoMLTrainModelOperator = ti.task
+ assert task.model == "model"
+ assert task.location == "location"
+ assert task.impersonation_chain == "impersonation_chain"
+
class TestAutoMLBatchPredictOperator:
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -118,6 +139,31 @@ class TestAutoMLBatchPredictOperator:
timeout=None,
)
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):  # each templated field is passed as a Jinja expression and must render to its literal
+ ti = create_task_instance_of_operator(
+ AutoMLBatchPredictOperator,
+ # Templated fields
+ model_id="{{ 'model' }}",
+ input_config="{{ 'input-config' }}",
+ output_config="{{ 'output-config' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()  # rendered values are read back from ti.task below
+ task: AutoMLBatchPredictOperator = ti.task
+ assert task.model_id == "model"
+ assert task.input_config == "input-config"
+ assert task.output_config == "output-config"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
class TestAutoMLPredictOperator:
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -144,6 +190,28 @@ class TestAutoMLPredictOperator:
timeout=None,
)
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):  # each templated field is passed as a Jinja expression and must render to its literal
+ ti = create_task_instance_of_operator(
+ AutoMLPredictOperator,
+ # Templated fields
+ model_id="{{ 'model-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ payload={},  # not asserted below; presumably a required constructor argument - confirm against the operator
+ )
+ ti.render_templates()  # rendered values are read back from ti.task below
+ task: AutoMLPredictOperator = ti.task
+ assert task.model_id == "model-id"
+ assert task.project_id == "project-id"
+ assert task.location == "location"
+ assert task.impersonation_chain == "impersonation-chain"
+
class TestAutoMLCreateImportOperator:
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -167,6 +235,27 @@ class TestAutoMLCreateImportOperator:
timeout=None,
)
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):  # each templated field is passed as a Jinja expression and must render to its literal
+ ti = create_task_instance_of_operator(
+ AutoMLCreateDatasetOperator,  # NOTE(review): operator under test is CreateDataset, while the enclosing test class is named ...CreateImportOperator
+ # Templated fields
+ dataset="{{ 'dataset' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()  # rendered values are read back from ti.task below
+ task: AutoMLCreateDatasetOperator = ti.task
+ assert task.dataset == "dataset"
+ assert task.project_id == "project-id"
+ assert task.location == "location"
+ assert task.impersonation_chain == "impersonation-chain"
+
class TestAutoMLListColumnsSpecsOperator:
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -199,6 +288,33 @@ class TestAutoMLListColumnsSpecsOperator:
timeout=None,
)
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):  # each templated field is passed as a Jinja expression and must render to its literal
+ ti = create_task_instance_of_operator(
+ AutoMLTablesListColumnSpecsOperator,
+ # Templated fields
+ dataset_id="{{ 'dataset-id' }}",
+ table_spec_id="{{ 'table-spec-id' }}",
+ field_mask="{{ 'field-mask' }}",
+ filter_="{{ 'filter-' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()  # rendered values are read back from ti.task below
+ task: AutoMLTablesListColumnSpecsOperator = ti.task
+ assert task.dataset_id == "dataset-id"
+ assert task.table_spec_id == "table-spec-id"
+ assert task.field_mask == "field-mask"
+ assert task.filter_ == "filter-"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
class TestAutoMLUpdateDatasetOperator:
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -223,6 +339,27 @@ class TestAutoMLUpdateDatasetOperator:
update_mask=MASK,
)
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):  # each templated field is passed as a Jinja expression and must render to its literal
+ ti = create_task_instance_of_operator(
+ AutoMLTablesUpdateDatasetOperator,
+ # Templated fields
+ dataset="{{ 'dataset' }}",
+ update_mask="{{ 'update-mask' }}",
+ location="{{ 'location' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()  # rendered values are read back from ti.task below
+ task: AutoMLTablesUpdateDatasetOperator = ti.task
+ assert task.dataset == "dataset"
+ assert task.update_mask == "update-mask"
+ assert task.location == "location"
+ assert task.impersonation_chain == "impersonation-chain"
+
class TestAutoMLGetModelOperator:
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -246,6 +383,27 @@ class TestAutoMLGetModelOperator:
timeout=None,
)
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):  # each templated field is passed as a Jinja expression and must render to its literal
+ ti = create_task_instance_of_operator(
+ AutoMLGetModelOperator,
+ # Templated fields
+ model_id="{{ 'model-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()  # rendered values are read back from ti.task below
+ task: AutoMLGetModelOperator = ti.task
+ assert task.model_id == "model-id"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
class TestAutoMLDeleteModelOperator:
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -266,6 +424,27 @@ class TestAutoMLDeleteModelOperator:
timeout=None,
)
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):  # each templated field is passed as a Jinja expression and must render to its literal
+ ti = create_task_instance_of_operator(
+ AutoMLDeleteModelOperator,
+ # Templated fields
+ model_id="{{ 'model-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()  # rendered values are read back from ti.task below
+ task: AutoMLDeleteModelOperator = ti.task
+ assert task.model_id == "model-id"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
class TestAutoMLDeployModelOperator:
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -289,6 +468,27 @@ class TestAutoMLDeployModelOperator:
timeout=None,
)
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):  # each templated field is passed as a Jinja expression and must render to its literal
+ ti = create_task_instance_of_operator(
+ AutoMLDeployModelOperator,
+ # Templated fields
+ model_id="{{ 'model-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()  # rendered values are read back from ti.task below
+ task: AutoMLDeployModelOperator = ti.task
+ assert task.model_id == "model-id"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
class TestAutoMLDatasetImportOperator:
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -311,6 +511,29 @@ class TestAutoMLDatasetImportOperator:
timeout=None,
)
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):  # each templated field is passed as a Jinja expression and must render to its literal
+ ti = create_task_instance_of_operator(
+ AutoMLImportDataOperator,
+ # Templated fields
+ dataset_id="{{ 'dataset-id' }}",
+ input_config="{{ 'input-config' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()  # rendered values are read back from ti.task below
+ task: AutoMLImportDataOperator = ti.task
+ assert task.dataset_id == "dataset-id"
+ assert task.input_config == "input-config"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
class TestAutoMLTablesListTableSpecsOperator:
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -338,6 +561,29 @@ class TestAutoMLTablesListTableSpecsOperator:
timeout=None,
)
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):  # each templated field is passed as a Jinja expression and must render to its literal
+ ti = create_task_instance_of_operator(
+ AutoMLTablesListTableSpecsOperator,
+ # Templated fields
+ dataset_id="{{ 'dataset-id' }}",
+ filter_="{{ 'filter-' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()  # rendered values are read back from ti.task below
+ task: AutoMLTablesListTableSpecsOperator = ti.task
+ assert task.dataset_id == "dataset-id"
+ assert task.filter_ == "filter-"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
class TestAutoMLDatasetListOperator:
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -352,6 +598,25 @@ class TestAutoMLDatasetListOperator:
timeout=None,
)
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):  # each templated field is passed as a Jinja expression and must render to its literal
+ ti = create_task_instance_of_operator(
+ AutoMLListDatasetOperator,
+ # Templated fields
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()  # rendered values are read back from ti.task below
+ task: AutoMLListDatasetOperator = ti.task
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"
+
class TestAutoMLDatasetDeleteOperator:
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
@@ -371,3 +636,24 @@ class TestAutoMLDatasetDeleteOperator:
retry=DEFAULT,
timeout=None,
)
+
+ @pytest.mark.db_test
+ def test_templating(self, create_task_instance_of_operator):  # each templated field is passed as a Jinja expression and must render to its literal
+ ti = create_task_instance_of_operator(
+ AutoMLDeleteDatasetOperator,
+ # Templated fields
+ dataset_id="{{ 'dataset-id' }}",
+ location="{{ 'location' }}",
+ project_id="{{ 'project-id' }}",
+ impersonation_chain="{{ 'impersonation-chain' }}",
+ # Other parameters
+ dag_id="test_template_body_templating_dag",
+ task_id="test_template_body_templating_task",
+ execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
+ )
+ ti.render_templates()  # rendered values are read back from ti.task below
+ task: AutoMLDeleteDatasetOperator = ti.task
+ assert task.dataset_id == "dataset-id"
+ assert task.location == "location"
+ assert task.project_id == "project-id"
+ assert task.impersonation_chain == "impersonation-chain"