kaxil commented on code in PR #55068:
URL: https://github.com/apache/airflow/pull/55068#discussion_r2995414896
##########
airflow-core/src/airflow/models/taskinstance.py:
##########
@@ -1590,6 +1590,73 @@ def update_heartbeat(self):
.values(last_heartbeat_at=timezone.utcnow())
)
+ @property
+ def start_trigger_args(self) -> StartTriggerArgs | None:
+ if self.task and self.task.start_from_trigger is True:
+ return self.task.start_trigger_args
+ return None
+
+ # TODO: We have some code duplication here and in the
_create_ti_state_update_query_and_update_state
+ # method of the task_instances module in the execution api when a
TIDeferredStatePayload is being
+ # processed. This is because of a TaskInstance being updated
differently using SQLAlchemy.
+ # If we use the approach from the execution api as common code in
the DagRun schedule_tis method,
+ # the side effect is the changes done to the task instance aren't
picked up by the scheduler and
+ # thus the task instance isn't processed until the scheduler is
restarted.
+ @provide_session
+ def defer_task(self, session: Session = NEW_SESSION) -> bool:
+ """
+ Mark the task as deferred and sets up the trigger that is needed to
resume it when TaskDeferred is raised.
+
+ :meta: private
+ """
+ from airflow.models.trigger import Trigger
+
+ if TYPE_CHECKING:
+ assert self.start_date
+ assert isinstance(self.task, Operator)
+
+ if start_trigger_args := self.start_trigger_args:
+ trigger_kwargs = start_trigger_args.trigger_kwargs or {}
+ timeout = start_trigger_args.timeout
+
+ # Calculate timeout too if it was passed
+ if timeout is not None:
+ self.trigger_timeout = timezone.utcnow() + timeout
+ else:
+ self.trigger_timeout = None
+
+ trigger_row = Trigger(
+ classpath=start_trigger_args.trigger_cls,
+ kwargs=trigger_kwargs,
+ )
+
+ # First, make the trigger entry
+ session.add(trigger_row)
+ session.flush()
+
+ # Then, update ourselves so it matches the deferral request
+ # Keep an eye on the logic in
`check_and_change_state_before_execution()`
+ # depending on self.next_method semantics
+ self.state = TaskInstanceState.DEFERRED
+ self.trigger_id = trigger_row.id
+ self.next_method = start_trigger_args.next_method
+ self.next_kwargs = start_trigger_args.next_kwargs or {}
+
+ # If an execution_timeout is set, set the timeout to the minimum of
+ # it and the trigger timeout
+ if execution_timeout := self.task.execution_timeout:
+ if self.trigger_timeout:
+ self.trigger_timeout = min(self.start_date +
execution_timeout, self.trigger_timeout)
+ else:
+ self.trigger_timeout = self.start_date + execution_timeout
+ self.start_date = timezone.utcnow()
+ if self.state != TaskInstanceState.UP_FOR_RESCHEDULE:
Review Comment:
This condition is always `True`. Line 1640 already sets `self.state =
TaskInstanceState.DEFERRED`, so by the time execution reaches this check, the
state is never `UP_FOR_RESCHEDULE`.
Result: `try_number` gets incremented for every `start_from_trigger`
deferral, including reschedule-mode sensors.
The original commented-out code in `dagrun.py` checked state *before*
calling `defer_task`. Fix: save the pre-deferral state before mutating it, or
move this check before line 1640.
##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -986,9 +1037,19 @@ async def init_comms(self):
raise RuntimeError(f"Required first message to be a
messages.StartTriggerer, it was {msg}")
async def create_triggers(self):
+ def create_runtime_ti(encoded_dag: dict) -> RuntimeTaskInstance:
+ task =
DagSerialization.from_dict(encoded_dag).get_task(workload.ti.task_id)
+
+ # I need to recreate a TaskInstance from task_runner before
invoking get_template_context (airflow.executors.workloads.TaskInstance)
+ return RuntimeTaskInstance.model_construct(
Review Comment:
**question**: `RuntimeTaskInstance.model_construct` bypasses Pydantic
validation. `RuntimeTaskInstance` declares `bundle_instance: BaseDagBundle` and
`start_date: AwareDatetime` as required fields (no defaults), so this creates
an instance missing those attributes.
Also, `_ti_context_from_server` defaults to `None`, so the context returned
by `get_template_context()` will be missing `ds`, `logical_date`, `dag_run`,
and most standard template variables -- only `task`, `dag`, `run_id`, `params`,
`var`, `conn` will be available.
Is this intentional? If so, it would help to document which template
variables are available when rendering in the triggerer, so operator authors
know the limitations.
##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -986,9 +1037,19 @@ async def init_comms(self):
raise RuntimeError(f"Required first message to be a
messages.StartTriggerer, it was {msg}")
async def create_triggers(self):
+ def create_runtime_ti(encoded_dag: dict) -> RuntimeTaskInstance:
+ task =
DagSerialization.from_dict(encoded_dag).get_task(workload.ti.task_id)
+
+ # I need to recreate a TaskInstance from task_runner before
invoking get_template_context (airflow.executors.workloads.TaskInstance)
+ return RuntimeTaskInstance.model_construct(
+ **workload.ti.model_dump(exclude_unset=True),
+ task=task,
+ )
+
"""Drain the to_create queue and create all new triggers that have
been requested in the DB."""
Review Comment:
This docstring is displaced below the inner `create_runtime_ti` function
definition. Python only recognizes a string literal that is the *first*
statement of a function body as its docstring, so `create_triggers.__doc__`
will be `None` (the first statement here is `def create_runtime_ti`, not a
string literal). This string is silently evaluated and discarded.
Move it immediately after `async def create_triggers(self):` (before the
inner function definition).
##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -658,6 +662,65 @@ def emit_metrics(self):
extra_tags={"hostname": self.job.hostname},
)
+ def _create_workload(
+ self,
+ trigger: Trigger,
+ dag_bag: DBDagBag,
+ render_log_fname: Callable[..., str],
+ session: Session,
+ ) -> workloads.RunTrigger | None:
+ if trigger.task_instance is None:
+ return workloads.RunTrigger(
+ id=trigger.id,
+ classpath=trigger.classpath,
+ encrypted_kwargs=trigger.encrypted_kwargs,
+ )
+
+ if not trigger.task_instance.dag_version_id:
+ # This is to handle 2 to 3 upgrade where TI.dag_version_id can be
none
+ log.warning(
+ "TaskInstance associated with Trigger has no associated Dag
Version, skipping the trigger",
+ ti_id=trigger.task_instance.id,
+ )
+ return None
+
+ log_path = render_log_fname(ti=trigger.task_instance)
+ ser_ti = TaskInstanceDTO.model_validate(trigger.task_instance,
from_attributes=True)
+
+ # When producing logs from TIs, include the job id producing the logs
to disambiguate it.
+ self.logger_cache[trigger.id] = TriggerLoggingFactory(
+ log_path=f"{log_path}.trigger.{self.job.id}.log",
+ ti=ser_ti, # type: ignore
+ )
+
+ serialized_dag_model = dag_bag.get_serialized_dag_model(
Review Comment:
`get_serialized_dag_model()` and the subsequent
`serialized_dag_model.dag.get_task()` deserialization run for every trigger
with a task instance, not just ones with `start_from_trigger=True`. Most
triggers (deferred sensors, etc.) don't use `start_from_trigger`, so this adds
unnecessary DB and CPU overhead for the common case.
Consider adding a lightweight indicator (e.g., a boolean flag on the Trigger
model or TI) so you can skip the DAG load entirely when `start_from_trigger`
isn't in play.
##########
airflow-core/src/airflow/models/taskinstance.py:
##########
@@ -1590,6 +1590,73 @@ def update_heartbeat(self):
.values(last_heartbeat_at=timezone.utcnow())
)
+ @property
+ def start_trigger_args(self) -> StartTriggerArgs | None:
+ if self.task and self.task.start_from_trigger is True:
+ return self.task.start_trigger_args
+ return None
+
+ # TODO: We have some code duplication here and in the
_create_ti_state_update_query_and_update_state
+ # method of the task_instances module in the execution api when a
TIDeferredStatePayload is being
+ # processed. This is because of a TaskInstance being updated
differently using SQLAlchemy.
+ # If we use the approach from the execution api as common code in
the DagRun schedule_tis method,
+ # the side effect is the changes done to the task instance aren't
picked up by the scheduler and
+ # thus the task instance isn't processed until the scheduler is
restarted.
+ @provide_session
+ def defer_task(self, session: Session = NEW_SESSION) -> bool:
+ """
+ Mark the task as deferred and sets up the trigger that is needed to
resume it when TaskDeferred is raised.
+
+ :meta: private
+ """
+ from airflow.models.trigger import Trigger
+
+ if TYPE_CHECKING:
+ assert self.start_date
+ assert isinstance(self.task, Operator)
+
+ if start_trigger_args := self.start_trigger_args:
+ trigger_kwargs = start_trigger_args.trigger_kwargs or {}
+ timeout = start_trigger_args.timeout
+
+ # Calculate timeout too if it was passed
+ if timeout is not None:
+ self.trigger_timeout = timezone.utcnow() + timeout
+ else:
+ self.trigger_timeout = None
+
+ trigger_row = Trigger(
+ classpath=start_trigger_args.trigger_cls,
+ kwargs=trigger_kwargs,
+ )
+
+ # First, make the trigger entry
+ session.add(trigger_row)
+ session.flush()
+
+ # Then, update ourselves so it matches the deferral request
+ # Keep an eye on the logic in
`check_and_change_state_before_execution()`
+ # depending on self.next_method semantics
+ self.state = TaskInstanceState.DEFERRED
+ self.trigger_id = trigger_row.id
+ self.next_method = start_trigger_args.next_method
+ self.next_kwargs = start_trigger_args.next_kwargs or {}
+
+ # If an execution_timeout is set, set the timeout to the minimum of
+ # it and the trigger timeout
+ if execution_timeout := self.task.execution_timeout:
+ if self.trigger_timeout:
+ self.trigger_timeout = min(self.start_date +
execution_timeout, self.trigger_timeout)
Review Comment:
`self.start_date` is used here for the `execution_timeout` calculation, but
it is only set to `timezone.utcnow()` three lines later (line 1652). For a
newly created TI that hasn't started yet, `start_date` could be `None`, causing
`TypeError: unsupported operand type(s) for +: 'NoneType' and
'datetime.timedelta'`.
Move the `self.start_date = timezone.utcnow()` assignment before this block.
##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -986,9 +1037,19 @@ async def init_comms(self):
raise RuntimeError(f"Required first message to be a
messages.StartTriggerer, it was {msg}")
async def create_triggers(self):
+ def create_runtime_ti(encoded_dag: dict) -> RuntimeTaskInstance:
Review Comment:
`create_runtime_ti` captures `workload` from the enclosing scope rather than
accepting it as a parameter. This works today because it's called in the same
loop iteration, but if this code is ever refactored to defer the call (e.g.,
collected into a list), it would silently use the wrong workload. Safer to pass
it explicitly:
```python
def create_runtime_ti(encoded_dag: dict, ti: TaskInstanceDTO) ->
RuntimeTaskInstance:
task = DagSerialization.from_dict(encoded_dag).get_task(ti.task_id)
return RuntimeTaskInstance.model_construct(
**ti.model_dump(exclude_unset=True),
task=task,
)
```
##########
airflow-core/src/airflow/triggers/base.py:
##########
@@ -66,14 +79,56 @@ class BaseTrigger(abc.ABC, LoggingMixin):
supports_triggerer_queue: bool = True
def __init__(self, **kwargs):
+ super().__init__()
# these values are set by triggerer when preparing to run the instance
# when run, they are injected into logger record.
- self.task_instance = None
+ self._task_instance = None
self.trigger_id = None
+ self.template_fields = ()
+ self.template_ext = ()
+ self.task_id = None
def _set_context(self, context):
"""Part of LoggingMixin and used mainly for configuration of task
logging; not used for triggers."""
- raise NotImplementedError
+ pass
+
+ @property
+ def task(self) -> Operator | None:
+ # We must check if the TaskInstance is the generated Pydantic one or
the RuntimeTaskInstance
+ if self.task_instance and hasattr(self.task_instance, "task"):
+ return self.task_instance.task
+ return None
+
+ @property
+ def task_instance(self) -> TaskInstance:
+ return self._task_instance
+
+ @task_instance.setter
+ def task_instance(self, value: TaskInstance | None) -> None:
+ self._task_instance = value
+ if self.task_instance:
+ self.task_id = self.task_instance.task_id
+ if self.task:
+ self.template_fields = self.task.template_fields
+ self.template_ext = self.task.template_ext
+
+ def render_template_fields(
+ self,
+ context: Context,
+ jinja_env: jinja2.Environment | None = None,
+ ) -> None:
+ """
+ Template all attributes listed in *self.template_fields*.
+
+ This mutates the attributes in-place and is irreversible.
+
+ :param context: Context dict with values to apply on content.
+ :param jinja_env: Jinja's environment to use for rendering.
+ """
+ if not jinja_env:
+ jinja_env = self.get_template_env()
+ # We only need to render templated fields if templated fields are part
of the start_trigger_args
Review Comment:
`self._do_render_template_fields(self, self.template_fields, ...)` renders
attributes on `self` (the trigger) using the *operator's* `template_fields`
names. If an operator has `template_fields = ("bash_command", "env", "cwd")`
but the trigger only has a `command` attribute, `getattr(trigger,
"bash_command")` will raise `AttributeError`.
This only works when the trigger happens to have attributes with the exact
same names as the operator's template fields. Should this render only the
fields from `start_trigger_args.trigger_kwargs` instead?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]