This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f3f69bf1e0 Refactor GKECreateClusterOperator's body validation (#31923)
f3f69bf1e0 is described below
commit f3f69bf1e0c025d260be91daada04476d2418e9d
Author: max <[email protected]>
AuthorDate: Thu Jun 29 17:56:17 2023 +0200
Refactor GKECreateClusterOperator's body validation (#31923)
---
.../google/cloud/operators/kubernetes_engine.py | 90 ++++++++++++++--------
1 file changed, 60 insertions(+), 30 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py
index 043eda20ff..55971f52b7 100644
--- a/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import warnings
from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from google.api_core.exceptions import AlreadyExists
from google.cloud.container_v1.types import Cluster
@@ -268,41 +268,71 @@ class GKECreateClusterOperator(GoogleCloudBaseOperator):
self.impersonation_chain = impersonation_chain
self.poll_interval = poll_interval
self.deferrable = deferrable
- self._check_input()
+ self._validate_input()
self._hook: GKEHook | None = None
- def _check_input(self) -> None:
- if (
- not all([self.project_id, self.location, self.body])
- or (isinstance(self.body, dict) and "name" not in self.body)
- or (
- isinstance(self.body, dict)
- and ("initial_node_count" not in self.body and "node_pools" not in self.body)
- )
- or (not (isinstance(self.body, dict)) and not (getattr(self.body, "name", None)))
- or (
- not (isinstance(self.body, dict))
- and (
- not (getattr(self.body, "initial_node_count", None))
- and not (getattr(self.body, "node_pools", None))
+ def _validate_input(self) -> None:
+ """Validate self.body: warn on deprecated fields, log every problem found, then raise AirflowException if any."""
+ self._alert_deprecated_body_fields()
+
+ error_messages: list[str] = []  # collect all problems so each one is logged before raising
+ if not self._body_field("name"):  # NOTE(review): unlike the removed check, project_id/location are no longer validated - confirm intended
+ error_messages.append("Field body['name'] is missing or incorrect")
+
+ if self._body_field("initial_node_count"):
+ if self._body_field("node_pools"):
+ error_messages.append(
+ "Do not use field body['initial_node_count'] and body['node_pools'] at the same time."
)
- )
- ):
- self.log.error(
- "One of (project_id, location, body, body['name'], "
- "body['initial_node_count']), body['node_pools'] is missing or incorrect"
- )
- raise AirflowException("Operator has incorrect or missing input.")
- elif (
- isinstance(self.body, dict) and ("initial_node_count" in self.body and "node_pools" in self.body)
- ) or (
- not (isinstance(self.body, dict))
- and (getattr(self.body, "initial_node_count", None) and getattr(self.body, "node_pools", None))
- ):
- self.log.error("Only one of body['initial_node_count']) and body['node_pools'] may be specified")
+
+ if self._body_field("node_config"):
+ if self._body_field("node_pools"):
+ error_messages.append(
+ "Do not use field body['node_config'] and body['node_pools'] at the same time."
+ )
+
+ if self._body_field("node_pools"):  # NOTE(review): overlaps the two checks above, so such conflicts are reported twice
+ if any([self._body_field("node_config"), self._body_field("initial_node_count")]):
+ error_messages.append(
+ "The field body['node_pools'] should not be set if "
+ "body['node_config'] or body['initial_node_count'] are specified."
+ )
+
+ if not any([self._body_field("node_config"), self._body_field("initial_node_count")]):
+ if not self._body_field("node_pools"):
+ error_messages.append(
+ "Field body['node_pools'] is required if none of fields "
+ "body['initial_node_count'] or body['node_config'] are specified."
+ )
+
+ for message in error_messages:
+ self.log.error(message)
+
+ if error_messages:
raise AirflowException("Operator has incorrect or missing input.")
+ def _body_field(self, field_name: str, default_value: Any = None) -> Any:
+ """Return field_name from self.body: a key lookup for dict bodies, an attribute lookup otherwise; default_value if absent."""
+ if isinstance(self.body, dict):
+ return self.body.get(field_name, default_value)
+ else:  # presumably a container_v1 Cluster object - confirm against the body type hint
+ return getattr(self.body, field_name, default_value)
+
+ def _alert_deprecated_body_fields(self) -> None:
+ """Emit one warning per deprecated body field that is set (truthy), naming the field to use instead."""
+ deprecated_body_fields_with_replacement = [  # NOTE(review): replacement names presumably mirror the GKE v1 Cluster API docs - confirm
+ ("initial_node_count", "node_pool.initial_node_count"),
+ ("node_config", "node_pool.config"),
+ ("zone", "location"),
+ ("instance_group_urls", "node_pools.instance_group_urls"),
+ ]
+ for deprecated_field, replacement in deprecated_body_fields_with_replacement:
+ if self._body_field(deprecated_field):
+ warnings.warn(  # no category given, so Python's default UserWarning is emitted
+ f"The body field '{deprecated_field}' is deprecated. Use '{replacement}' instead."
+ )
+
def execute(self, context: Context) -> str:
hook = self._get_hook()
try: