This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-1-test by this push:
new 8c8fb58ca7a [v3-1-test] fix: HITL params not validating (#57547)
(#58144)
8c8fb58ca7a is described below
commit 8c8fb58ca7a7726a31fa8760623b9347186322c8
Author: Wei Lee <[email protected]>
AuthorDate: Tue Nov 11 14:43:43 2025 +0800
[v3-1-test] fix: HITL params not validating (#57547) (#58144)
---
.../api_fastapi/core_api/datamodels/hitl.py | 18 ++-
airflow-core/src/airflow/ui/src/utils/hitl.ts | 51 ++++----
.../core_api/routes/public/test_hitl.py | 4 +-
.../src/tests_common/test_utils/version_compat.py | 1 +
.../airflow/providers/standard/operators/hitl.py | 19 +--
.../airflow/providers/standard/triggers/hitl.py | 135 ++++++++++++++++++++-
.../airflow/providers/standard/version_compat.py | 1 +
.../tests/unit/standard/operators/test_hitl.py | 133 +++++++++++++++++---
.../tests/unit/standard/triggers/test_hitl.py | 36 ++++--
task-sdk/src/airflow/sdk/api/client.py | 2 +-
task-sdk/src/airflow/sdk/definitions/param.py | 1 -
.../tests/task_sdk/execution_time/test_hitl.py | 4 +-
12 files changed, 336 insertions(+), 69 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/hitl.py
b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/hitl.py
index aa4f44f212b..4fa79eefe50 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/hitl.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/hitl.py
@@ -24,7 +24,6 @@ from pydantic import Field, field_validator
from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.datamodels.task_instances import
TaskInstanceResponse
-from airflow.sdk import Param
class UpdateHITLDetailPayload(BaseModel):
@@ -61,7 +60,7 @@ class HITLDetail(BaseModel):
body: str | None = None
defaults: list[str] | None = None
multiple: bool = False
- params: dict[str, Any] = Field(default_factory=dict)
+ params: Mapping = Field(default_factory=dict)
assigned_users: list[HITLUser] = Field(default_factory=list)
created_at: datetime
@@ -77,7 +76,20 @@ class HITLDetail(BaseModel):
@classmethod
def get_params(cls, params: dict[str, Any]) -> dict[str, Any]:
"""Convert params attribute to dict representation."""
- return {k: v.dump() if isinstance(v, Param) else v for k, v in
params.items()}
+ return {
+ key: value
+ if HITLDetail._is_param(value)
+ else {
+ "value": value,
+ "description": None,
+ "schema": {},
+ }
+ for key, value in params.items()
+ }
+
+ @staticmethod
+ def _is_param(value: Any) -> bool:
+ return isinstance(value, dict) and all(key in value for key in
("description", "schema", "value"))
class HITLDetailCollection(BaseModel):
diff --git a/airflow-core/src/airflow/ui/src/utils/hitl.ts
b/airflow-core/src/airflow/ui/src/utils/hitl.ts
index 7d38a09374d..bd3a4051754 100644
--- a/airflow-core/src/airflow/ui/src/utils/hitl.ts
+++ b/airflow-core/src/airflow/ui/src/utils/hitl.ts
@@ -19,7 +19,7 @@
import type { TFunction } from "i18next";
import type { HITLDetail } from "openapi/requests/types.gen";
-import type { ParamsSpec } from "src/queries/useDagParams";
+import type { ParamSchema, ParamsSpec } from "src/queries/useDagParams";
export type HITLResponseParams = {
chosen_options?: Array<string>;
@@ -70,7 +70,7 @@ export const getHITLParamsDict = (
searchParams: URLSearchParams,
): ParamsSpec => {
const paramsDict: ParamsSpec = {};
- const { preloadedHITLOptions, preloadedHITLParams } =
getPreloadHITLFormData(searchParams, hitlDetail);
+ const { preloadedHITLOptions } = getPreloadHITLFormData(searchParams,
hitlDetail);
const isApprovalTask =
hitlDetail.options.includes("Approve") &&
hitlDetail.options.includes("Reject") &&
@@ -108,27 +108,36 @@ export const getHITLParamsDict = (
const sourceParams = hitlDetail.response_received ?
hitlDetail.params_input : hitlDetail.params;
Object.entries(sourceParams ?? {}).forEach(([key, value]) => {
- const valueType = typeof value === "number" ? "number" : "string";
+ if (!hitlDetail.params) {
+ return;
+ }
+ const paramData = hitlDetail.params[key] as ParamsSpec | undefined;
+
+ const description: string =
+ paramData && typeof paramData.description === "string" ?
paramData.description : "";
+
+ const schema: ParamSchema = {
+ const: undefined,
+ description_md: "",
+ enum: undefined,
+ examples: undefined,
+ format: undefined,
+ items: undefined,
+ maximum: undefined,
+ maxLength: undefined,
+ minimum: undefined,
+ minLength: undefined,
+ section: undefined,
+ title: key,
+ type: typeof value === "number" ? "number" : "string",
+ values_display: undefined,
+ ...(paramData?.schema && typeof paramData.schema === "object" ?
paramData.schema : {}),
+ };
paramsDict[key] = {
- description: "",
- schema: {
- const: undefined,
- description_md: "",
- enum: undefined,
- examples: undefined,
- format: undefined,
- items: undefined,
- maximum: undefined,
- maxLength: undefined,
- minimum: undefined,
- minLength: undefined,
- section: undefined,
- title: key,
- type: valueType,
- values_display: undefined,
- },
- value: preloadedHITLParams[key] ?? value,
+ description,
+ schema,
+ value: paramData?.value ?? value,
};
});
}
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py
index 2429f2ea093..4c10feee7d6 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py
@@ -216,7 +216,7 @@ def expected_sample_hitl_detail_dict(sample_ti:
TaskInstance) -> dict[str, Any]:
"defaults": ["Approve"],
"multiple": False,
"options": ["Approve", "Reject"],
- "params": {"input_1": 1},
+ "params": {"input_1": {"value": 1, "schema": {}, "description": None}},
"assigned_users": [],
"created_at": mock.ANY,
"params_input": {},
@@ -619,7 +619,7 @@ class TestGetHITLDetailsEndpoint:
"body": "this is body 0",
"defaults": ["Approve"],
"multiple": False,
- "params": {"input_1": 1},
+ "params": {"input_1": {"value": 1, "schema": {},
"description": None}},
"assigned_users": [],
"created_at":
DEFAULT_CREATED_AT.isoformat().replace("+00:00", "Z"),
"responded_by_user": None,
diff --git a/devel-common/src/tests_common/test_utils/version_compat.py
b/devel-common/src/tests_common/test_utils/version_compat.py
index 50d50d2a1cc..25cb20a528c 100644
--- a/devel-common/src/tests_common/test_utils/version_compat.py
+++ b/devel-common/src/tests_common/test_utils/version_compat.py
@@ -36,6 +36,7 @@ AIRFLOW_V_3_0_1 = get_base_airflow_version_tuple() == (3, 0,
1)
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
AIRFLOW_V_3_0_3_PLUS = get_base_airflow_version_tuple() >= (3, 0, 3)
AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
+AIRFLOW_V_3_1_3_PLUS = get_base_airflow_version_tuple() >= (3, 1, 3)
if AIRFLOW_V_3_1_PLUS:
diff --git
a/providers/standard/src/airflow/providers/standard/operators/hitl.py
b/providers/standard/src/airflow/providers/standard/operators/hitl.py
index f40abe1e467..4d7b258011c 100644
--- a/providers/standard/src/airflow/providers/standard/operators/hitl.py
+++ b/providers/standard/src/airflow/providers/standard/operators/hitl.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import logging
from airflow.exceptions import AirflowOptionalProviderFeatureException
-from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_3_PLUS,
AIRFLOW_V_3_1_PLUS
if not AIRFLOW_V_3_1_PLUS:
raise AirflowOptionalProviderFeatureException("Human in the loop
functionality needs Airflow 3.1+.")
@@ -84,6 +84,7 @@ class HITLOperator(BaseOperator):
self.multiple = multiple
self.params: ParamsDict = params if isinstance(params, ParamsDict)
else ParamsDict(params or {})
+
self.notifiers: Sequence[BaseNotifier] = (
[notifiers] if isinstance(notifiers, BaseNotifier) else notifiers
or []
)
@@ -110,6 +111,7 @@ class HITLOperator(BaseOperator):
Raises:
ValueError: If `"_options"` key is present in `params`, which is
not allowed.
"""
+ self.params.validate()
if "_options" in self.params:
raise ValueError('"_options" is not allowed in params')
@@ -165,8 +167,10 @@ class HITLOperator(BaseOperator):
)
@property
- def serialized_params(self) -> dict[str, Any]:
- return self.params.dump() if isinstance(self.params, ParamsDict) else
self.params
+ def serialized_params(self) -> dict[str, dict[str, Any]]:
+ if not AIRFLOW_V_3_1_3_PLUS:
+ return self.params.dump() if isinstance(self.params, ParamsDict)
else self.params
+ return {k: self.params.get_param(k).serialize() for k in self.params}
def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
if "error" in event:
@@ -196,13 +200,12 @@ class HITLOperator(BaseOperator):
def validate_params_input(self, params_input: Mapping) -> None:
"""Check whether user provide valid params input."""
- if (
- self.serialized_params is not None
- and params_input is not None
- and set(self.serialized_params.keys()) ^ set(params_input)
- ):
+ if self.params and params_input and set(self.serialized_params.keys())
^ set(params_input):
raise ValueError(f"params_input {params_input} does not match
params {self.params}")
+ for key, value in params_input.items():
+ self.params[key] = value
+
def generate_link_to_ui(
self,
*,
diff --git a/providers/standard/src/airflow/providers/standard/triggers/hitl.py
b/providers/standard/src/airflow/providers/standard/triggers/hitl.py
index b0b8ec71ef9..483cb0be874 100644
--- a/providers/standard/src/airflow/providers/standard/triggers/hitl.py
+++ b/providers/standard/src/airflow/providers/standard/triggers/hitl.py
@@ -30,6 +30,9 @@ from uuid import UUID
from asgiref.sync import sync_to_async
+from airflow.exceptions import ParamValidationError
+from airflow.sdk import Param
+from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.execution_time.hitl import (
HITLUser,
get_hitl_detail_content_detail,
@@ -43,7 +46,7 @@ class HITLTriggerEventSuccessPayload(TypedDict, total=False):
"""Minimum required keys for a success Human-in-the-loop TriggerEvent."""
chosen_options: list[str]
- params_input: dict[str, Any]
+ params_input: dict[str, dict[str, Any]]
responded_by_user: HITLUser | None
responded_at: datetime
timedout: bool
@@ -53,7 +56,7 @@ class HITLTriggerEventFailurePayload(TypedDict):
"""Minimum required keys for a failed Human-in-the-loop TriggerEvent."""
error: str
- error_type: Literal["timeout", "unknown"]
+ error_type: Literal["timeout", "unknown", "validation"]
class HITLTrigger(BaseTrigger):
@@ -64,7 +67,7 @@ class HITLTrigger(BaseTrigger):
*,
ti_id: UUID,
options: list[str],
- params: dict[str, Any],
+ params: dict[str, dict[str, Any]],
defaults: list[str] | None = None,
multiple: bool = False,
timeout_datetime: datetime | None,
@@ -80,7 +83,21 @@ class HITLTrigger(BaseTrigger):
self.defaults = defaults
self.timeout_datetime = timeout_datetime
- self.params = params
+ self.params = ParamsDict(
+ {
+ k: Param(
+ v.pop("value"),
+ **v,
+ )
+ if HITLTrigger._is_param(v)
+ else Param(v)
+ for k, v in params.items()
+ },
+ )
+
+ @staticmethod
+ def _is_param(value: Any) -> bool:
+ return isinstance(value, dict) and all(key in value for key in
("description", "schema", "value"))
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize HITLTrigger arguments and classpath."""
@@ -90,13 +107,120 @@ class HITLTrigger(BaseTrigger):
"ti_id": self.ti_id,
"options": self.options,
"defaults": self.defaults,
- "params": self.params,
+ "params": {k: self.params.get_param(k).serialize() for k in
self.params},
"multiple": self.multiple,
"timeout_datetime": self.timeout_datetime,
"poke_interval": self.poke_interval,
},
)
+ async def _handle_timeout(self) -> TriggerEvent:
+ """Handle HITL timeout logic and yield appropriate event."""
+ resp = await
sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id)
+
+ # Case 1: Response arrived just before timeout
+ if resp.response_received and resp.chosen_options:
+ if TYPE_CHECKING:
+ assert resp.responded_by_user is not None
+ assert resp.responded_at is not None
+
+ chosen_options_list = list(resp.chosen_options or [])
+ self.log.info(
+ "[HITL] responded_by=%s (id=%s) options=%s at %s (timeout
fallback skipped)",
+ resp.responded_by_user.name,
+ resp.responded_by_user.id,
+ chosen_options_list,
+ resp.responded_at,
+ )
+ return TriggerEvent(
+ HITLTriggerEventSuccessPayload(
+ chosen_options=chosen_options_list,
+ params_input=resp.params_input or {},
+ responded_at=resp.responded_at,
+ responded_by_user=HITLUser(
+ id=resp.responded_by_user.id,
+ name=resp.responded_by_user.name,
+ ),
+ timedout=False,
+ )
+ )
+
+ # Case 2: No defaults defined → failure
+ if self.defaults is None:
+ return TriggerEvent(
+ HITLTriggerEventFailurePayload(
+ error="The timeout has passed, and the response has not
yet been received.",
+ error_type="timeout",
+ )
+ )
+
+ # Case 3: Timeout fallback to default
+ resp = await sync_to_async(update_hitl_detail_response)(
+ ti_id=self.ti_id,
+ chosen_options=self.defaults,
+ params_input=self.params.dump(),
+ )
+ if TYPE_CHECKING:
+ assert resp.responded_at is not None
+
+ self.log.info(
+ "[HITL] timeout reached before receiving response, fallback to
default %s",
+ self.defaults,
+ )
+ return TriggerEvent(
+ HITLTriggerEventSuccessPayload(
+ chosen_options=self.defaults,
+ params_input=self.params.dump(),
+ responded_by_user=None,
+ responded_at=resp.responded_at,
+ timedout=True,
+ )
+ )
+
+ async def _handle_response(self):
+ """Check if HITL response is ready and yield success if so."""
+ resp = await
sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id)
+ if TYPE_CHECKING:
+ assert resp.responded_by_user is not None
+ assert resp.responded_at is not None
+
+ if not (resp.response_received and resp.chosen_options):
+ return None
+
+ # validate input
+ if params_input := resp.params_input:
+ try:
+ for key, value in params_input.items():
+ self.params[key] = value
+ except ParamValidationError as err:
+ return TriggerEvent(
+ HITLTriggerEventFailurePayload(
+ error=str(err),
+ error_type="validation",
+ )
+ )
+
+ chosen_options_list = list(resp.chosen_options or [])
+ self.log.info(
+ "[HITL] responded_by=%s (id=%s) options=%s at %s",
+ resp.responded_by_user.name,
+ resp.responded_by_user.id,
+ chosen_options_list,
+ resp.responded_at,
+ )
+ return TriggerEvent(
+ HITLTriggerEventSuccessPayload(
+ chosen_options=chosen_options_list,
+ params_input=params_input or {},
+ responded_at=resp.responded_at,
+ responded_by_user=HITLUser(
+ id=resp.responded_by_user.id,
+ name=resp.responded_by_user.name,
+ ),
+ timedout=False,
+ )
+ )
+
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Loop until the Human-in-the-loop response received or timeout
reached."""
while True:
@@ -185,4 +309,5 @@ class HITLTrigger(BaseTrigger):
)
)
return
+
await asyncio.sleep(self.poke_interval)
diff --git
a/providers/standard/src/airflow/providers/standard/version_compat.py
b/providers/standard/src/airflow/providers/standard/version_compat.py
index d36955c6776..71623adc03a 100644
--- a/providers/standard/src/airflow/providers/standard/version_compat.py
+++ b/providers/standard/src/airflow/providers/standard/version_compat.py
@@ -34,6 +34,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_0_PLUS: bool = get_base_airflow_version_tuple() >= (3, 0, 0)
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
+AIRFLOW_V_3_1_3_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 3)
# BaseOperator is not imported from SDK from 3.0 (and only done from 3.1) due
to a bug with
# DecoratedOperator -- where `DecoratedOperator._handle_output` needed
`xcom_push` to exist on `BaseOperator`
diff --git a/providers/standard/tests/unit/standard/operators/test_hitl.py
b/providers/standard/tests/unit/standard/operators/test_hitl.py
index f198cc461cf..7a8c5b91bf1 100644
--- a/providers/standard/tests/unit/standard/operators/test_hitl.py
+++ b/providers/standard/tests/unit/standard/operators/test_hitl.py
@@ -18,8 +18,6 @@ from __future__ import annotations
import pytest
-from airflow.providers.standard.exceptions import HITLTimeoutError,
HITLTriggerEventError
-
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
if not AIRFLOW_V_3_1_PLUS:
@@ -33,9 +31,10 @@ from urllib.parse import parse_qs, urlparse
import pytest
from sqlalchemy import select
-from airflow.exceptions import AirflowException, DownstreamTasksSkipped
+from airflow.exceptions import AirflowException, DownstreamTasksSkipped,
ParamValidationError
from airflow.models import TaskInstance, Trigger
from airflow.models.hitl import HITLDetail
+from airflow.providers.standard.exceptions import HITLTimeoutError,
HITLTriggerEventError
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.hitl import (
ApprovalOperator,
@@ -48,6 +47,7 @@ from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.execution_time.hitl import HITLUser
from tests_common.test_utils.config import conf_vars
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_3_PLUS
if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -103,9 +103,27 @@ class TestHITLOperator:
params=ParamsDict({"input_1": 1}),
)
- def test_validate_params_with__options(self) -> None:
+ @pytest.mark.parametrize(
+ ("params", "exc", "error_msg"),
+ (
+ (ParamsDict({"_options": 1}), ValueError, '"_options" is not
allowed in params'),
+ (
+ ParamsDict({"param": Param("", type="integer")}),
+ ParamValidationError,
+ (
+ "Invalid input for param param: '' is not of type
'integer'\n\n"
+ "Failed validating 'type' in schema:\n"
+ " {'type': 'integer'}\n\n"
+ "On instance:\n ''"
+ ),
+ ),
+ ),
+ )
+ def test_validate_params(
+ self, params: ParamsDict, exc: type[ValueError |
ParamValidationError], error_msg: str
+ ) -> None:
# validate_params is called during initialization
- with pytest.raises(ValueError, match='"_options" is not allowed in
params'):
+ with pytest.raises(exc, match=error_msg):
HITLOperator(
task_id="hitl_test",
subject="This is subject",
@@ -113,7 +131,7 @@ class TestHITLOperator:
body="This is body",
defaults=["1"],
multiple=False,
- params=ParamsDict({"_options": 1}),
+ params=params,
)
def test_validate_defaults(self) -> None:
@@ -183,12 +201,21 @@ class TestHITLOperator:
assert hitl_detail_model.body == "This is body"
assert hitl_detail_model.defaults == ["1"]
assert hitl_detail_model.multiple is False
- assert hitl_detail_model.params == {"input_1": 1}
assert hitl_detail_model.assignees == [{"id": "test", "name": "test"}]
assert hitl_detail_model.responded_at is None
assert hitl_detail_model.responded_by is None
assert hitl_detail_model.chosen_options is None
assert hitl_detail_model.params_input == {}
+ if AIRFLOW_V_3_1_3_PLUS:
+ assert hitl_detail_model.params == {
+ "input_1": {
+ "value": 1,
+ "description": None,
+ "schema": {},
+ }
+ }
+ else:
+ assert hitl_detail_model.params == {"input_1": 1}
assert notifier.called is True
@@ -199,17 +226,55 @@ class TestHITLOperator:
"ti_id": ti.id,
"options": ["1", "2", "3", "4", "5"],
"defaults": ["1"],
- "params": {"input_1": 1},
+ "params": {
+ "input_1": {
+ "value": 1,
+ "description": None,
+ "schema": {},
+ }
+ },
"multiple": False,
"timeout_datetime": None,
"poke_interval": 5.0,
}
+ @pytest.mark.skipif(not AIRFLOW_V_3_1_3_PLUS, reason="This only works in
airflow-core >= 3.1.3")
@pytest.mark.parametrize(
"input_params, expected_params",
[
- (ParamsDict({"input": 1}), {"input": 1}),
- ({"input": Param(5, type="integer", minimum=3)}, {"input": 5}),
+ (
+ ParamsDict({"input": 1}),
+ {
+ "input": {
+ "description": None,
+ "schema": {},
+ "value": 1,
+ },
+ },
+ ),
+ (
+ {"input": Param(5, type="integer", minimum=3,
description="test")},
+ {
+ "input": {
+ "value": 5,
+ "schema": {
+ "minimum": 3,
+ "type": "integer",
+ },
+ "description": "test",
+ }
+ },
+ ),
+ (
+ {"input": 1},
+ {
+ "input": {
+ "value": 1,
+ "schema": {},
+ "description": None,
+ }
+ },
+ ),
(None, {}),
],
)
@@ -223,6 +288,20 @@ class TestHITLOperator:
)
assert hitl_op.serialized_params == expected_params
+ @pytest.mark.skipif(
+ AIRFLOW_V_3_1_3_PLUS,
+ reason="Preserve the old behavior if airflow-core < 3.1.3. Otherwise
the UI will break.",
+ )
+ def test_serialzed_params_legacy(self) -> None:
+ hitl_op = HITLOperator(
+ task_id="hitl_test",
+ subject="This is subject",
+ body="This is body",
+ options=["1", "2", "3", "4", "5"],
+ params={"input": Param(1)},
+ )
+ assert hitl_op.serialized_params == {"input": 1}
+
def test_execute_complete(self) -> None:
hitl_op = HITLOperator(
task_id="hitl_test",
@@ -292,21 +371,47 @@ class TestHITLOperator:
},
)
- def test_validate_params_input_with_invalid_input(self) -> None:
+ @pytest.mark.parametrize(
+ ("params", "params_input", "exc", "error_msg"),
+ (
+ (
+ ParamsDict({"input": 1}),
+ {"no such key": 2, "input": 333},
+ ValueError,
+ "params_input {'no such key': 2, 'input': 333} does not match
params {'input': 1}",
+ ),
+ (
+ ParamsDict({"input": Param(3, type="number", minimum=3)}),
+ {"input": 0},
+ ParamValidationError,
+ (
+ "Invalid input for param input: 0 is less than the minimum
of 3\n\n"
+ "Failed validating 'minimum' in schema:\n.*"
+ ),
+ ),
+ ),
+ )
+ def test_validate_params_input_with_invalid_input(
+ self,
+ params: ParamsDict,
+ params_input: dict[str, Any],
+ exc: type[ValueError | ParamValidationError],
+ error_msg: str,
+ ) -> None:
hitl_op = HITLOperator(
task_id="hitl_test",
subject="This is subject",
body="This is body",
options=["1", "2", "3", "4", "5"],
- params={"input": 1},
+ params=params,
)
- with pytest.raises(ValueError):
+ with pytest.raises(exc, match=error_msg):
hitl_op.execute_complete(
context={},
event={
"chosen_options": ["1"],
- "params_input": {"no such key": 2, "input": 333},
+ "params_input": params_input,
"responded_by_user": {"id": "test", "name": "test"},
},
)
diff --git a/providers/standard/tests/unit/standard/triggers/test_hitl.py
b/providers/standard/tests/unit/standard/triggers/test_hitl.py
index 1441d04bfc0..adb82ff00c0 100644
--- a/providers/standard/tests/unit/standard/triggers/test_hitl.py
+++ b/providers/standard/tests/unit/standard/triggers/test_hitl.py
@@ -17,6 +17,8 @@
from __future__ import annotations
+from typing import Any
+
import pytest
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
@@ -40,16 +42,22 @@ from airflow.providers.standard.triggers.hitl import (
from airflow.triggers.base import TriggerEvent
TI_ID = uuid7()
-default_trigger_args = {
- "ti_id": TI_ID,
- "options": ["1", "2", "3", "4", "5"],
- "params": {"input": 1},
- "multiple": False,
-}
+
+
+@pytest.fixture
+def default_trigger_args() -> dict[str, Any]:
+ return {
+ "ti_id": TI_ID,
+ "options": ["1", "2", "3", "4", "5"],
+ "params": {
+ "input": {"value": 1, "schema": {}, "description": None},
+ },
+ "multiple": False,
+ }
class TestHITLTrigger:
- def test_serialization(self):
+ def test_serialization(self, default_trigger_args):
trigger = HITLTrigger(
defaults=["1"],
timeout_datetime=None,
@@ -61,7 +69,7 @@ class TestHITLTrigger:
assert kwargs == {
"ti_id": TI_ID,
"options": ["1", "2", "3", "4", "5"],
- "params": {"input": 1},
+ "params": {"input": {"value": 1, "description": None, "schema":
{}}},
"defaults": ["1"],
"multiple": False,
"timeout_datetime": None,
@@ -71,7 +79,7 @@ class TestHITLTrigger:
@pytest.mark.db_test
@pytest.mark.asyncio
@mock.patch("airflow.sdk.execution_time.hitl.update_hitl_detail_response")
- async def test_run_failed_due_to_timeout(self, mock_update,
mock_supervisor_comms):
+ async def test_run_failed_due_to_timeout(self, mock_update,
mock_supervisor_comms, default_trigger_args):
trigger = HITLTrigger(
timeout_datetime=utcnow() + timedelta(seconds=0.1),
poke_interval=5,
@@ -100,7 +108,9 @@ class TestHITLTrigger:
@pytest.mark.asyncio
@mock.patch.object(HITLTrigger, "log")
@mock.patch("airflow.sdk.execution_time.hitl.update_hitl_detail_response")
- async def test_run_fallback_to_default_due_to_timeout(self, mock_update,
mock_log, mock_supervisor_comms):
+ async def test_run_fallback_to_default_due_to_timeout(
+ self, mock_update, mock_log, mock_supervisor_comms,
default_trigger_args
+ ):
trigger = HITLTrigger(
defaults=["1"],
timeout_datetime=utcnow() + timedelta(seconds=0.1),
@@ -139,7 +149,7 @@ class TestHITLTrigger:
@mock.patch.object(HITLTrigger, "log")
@mock.patch("airflow.sdk.execution_time.hitl.update_hitl_detail_response")
async def test_run_should_check_response_in_timeout_handler(
- self, mock_update, mock_log, mock_supervisor_comms
+ self, mock_update, mock_log, mock_supervisor_comms,
default_trigger_args
):
# action time only slightly before timeout
action_datetime = utcnow() + timedelta(seconds=0.1)
@@ -186,7 +196,9 @@ class TestHITLTrigger:
@pytest.mark.asyncio
@mock.patch.object(HITLTrigger, "log")
@mock.patch("airflow.sdk.execution_time.hitl.update_hitl_detail_response")
- async def test_run(self, mock_update, mock_log, mock_supervisor_comms,
time_machine):
+ async def test_run(
+ self, mock_update, mock_log, mock_supervisor_comms, time_machine,
default_trigger_args
+ ):
time_machine.move_to(datetime(2025, 7, 29, 2, 0, 0))
trigger = HITLTrigger(
diff --git a/task-sdk/src/airflow/sdk/api/client.py
b/task-sdk/src/airflow/sdk/api/client.py
index 2a83bb8d7da..1d1db8cfb07 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -728,7 +728,7 @@ class HITLOperations:
body: str | None = None,
defaults: list[str] | None = None,
multiple: bool = False,
- params: dict[str, Any] | None = None,
+ params: dict[str, dict[str, Any]] | None = None,
assigned_users: list[HITLUser] | None = None,
) -> HITLDetailRequest:
"""Add a Human-in-the-loop response that waits for human response for
a specific Task Instance."""
diff --git a/task-sdk/src/airflow/sdk/definitions/param.py
b/task-sdk/src/airflow/sdk/definitions/param.py
index 5da589d79e9..2c853ce1ffc 100644
--- a/task-sdk/src/airflow/sdk/definitions/param.py
+++ b/task-sdk/src/airflow/sdk/definitions/param.py
@@ -137,7 +137,6 @@ class ParamsDict(MutableMapping[str, Any]):
if they are not already. This class is to replace param's dictionary
implicitly
and ideally not needed to be used directly.
-
:param dict_obj: A dict or dict like object to init ParamsDict
:param suppress_exception: Flag to suppress value exceptions while
initializing the ParamsDict
"""
diff --git a/task-sdk/tests/task_sdk/execution_time/test_hitl.py
b/task-sdk/tests/task_sdk/execution_time/test_hitl.py
index cd3682d922e..5eb2dc7dab4 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_hitl.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_hitl.py
@@ -39,7 +39,7 @@ def test_upsert_hitl_detail(mock_supervisor_comms) -> None:
subject="Subject",
body="Optional body",
defaults=["Approve", "Reject"],
- params={"input_1": 1},
+ params={"input_1": {"value": 1, "description": None, "schema": {}}},
assigned_users=[HITLUser(id="test", name="test")],
multiple=False,
)
@@ -50,7 +50,7 @@ def test_upsert_hitl_detail(mock_supervisor_comms) -> None:
subject="Subject",
body="Optional body",
defaults=["Approve", "Reject"],
- params={"input_1": 1},
+ params={"input_1": {"value": 1, "description": None, "schema":
{}}},
assigned_users=[APIHITLUser(id="test", name="test")],
multiple=False,
)