Lee-W commented on code in PR #55068:
URL: https://github.com/apache/airflow/pull/55068#discussion_r2497833781


##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -46,24 +46,29 @@ class DBDagBag:
     """
 
     def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[str, SerializedDAG] = {}  # dag_version_id to dag
+        self._dags: dict[str, SerializedDagModel] = {}  # dag_version_id to dag
         self.load_op_links = load_op_links
 
     def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
         serdag.load_op_links = self.load_op_links
         if dag := serdag.dag:
-            self._dags[serdag.dag_version_id] = dag
+            self._dags[serdag.dag_version_id] = serdag
         return dag
 
-    def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | 
None:
-        if dag := self._dags.get(version_id):
-            return dag
-        dag_version = session.get(DagVersion, version_id, 
options=[joinedload(DagVersion.serialized_dag)])
-        if not dag_version:
-            return None
-        if not (serdag := dag_version.serialized_dag):
-            return None
-        return self._read_dag(serdag)
+    def get_dag_model(self, version_id: str, session: Session) -> 
SerializedDagModel | None:
+        if not (serdag := self._dags.get(version_id)):
+            dag_version = session.get(DagVersion, version_id, 
options=[joinedload(DagVersion.serialized_dag)])
+            if not dag_version:
+                return None
+            if not (serdag := dag_version.serialized_dag):
+                return None

Review Comment:
   ```suggestion
               if not dag_version or not (serdag := dag_version.serialized_dag):
                   return None
   ```



##########
airflow-core/src/airflow/models/taskinstance.py:
##########
@@ -1446,6 +1447,79 @@ def update_heartbeat(self):
                 .values(last_heartbeat_at=timezone.utcnow())
             )
 
+    def start_trigger_args(self) -> StartTriggerArgs | None:
+        if self.task:
+            if self.task.is_mapped:
+                context = self.get_template_context()
+                if self.task.expand_start_from_trigger(context=context):
+                    return self.task.expand_start_trigger_args(context=context)
+            elif self.task.start_from_trigger:
+                return self.task.start_trigger_args
+        return None
+
+    # TODO: We have some code duplication here and in the 
_create_ti_state_update_query_and_update_state
+    #       method of the task_instances module in the execution api when a 
TIDeferredStatePayload is being
+    #       processed.  This is because of a TaskInstance being updated 
differently using SQLAlchemy.

Review Comment:
   ```suggestion
       #       processed. This is because of a TaskInstance being updated 
differently using SQLAlchemy.
   ```



##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -130,6 +130,9 @@ def get_task_instance(self, session: Session) -> 
TaskInstance:
 
             :param session: Sqlalchemy session
             """
+            if not self.task_instance:
+                raise AirflowException(f"TaskInstance not set on 
{self.__class__.__name__}!")

Review Comment:
   ```suggestion
                   raise RuntimeError(f"TaskInstance not set on 
{self.__class__.__name__}!")
   ```



##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -150,23 +153,27 @@ def get_task_instance(self, session: Session) -> 
TaskInstance:
     async def get_task_state(self):
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
-            dag_id=self.task_instance.dag_id,
-            task_ids=[self.task_instance.task_id],
-            run_ids=[self.task_instance.run_id],
-            map_index=self.task_instance.map_index,
-        )
-        try:
-            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
-        except Exception:
-            raise AirflowException(
-                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
-                self.task_instance.dag_id,
-                self.task_instance.task_id,
-                self.task_instance.run_id,
-                self.task_instance.map_index,
+        if not self.task_instance:
+            raise AirflowException(f"TaskInstance not set on 
{self.__class__.__name__}!")
+
+        if not isinstance(self.task_instance, RuntimeTaskInstance):
+            task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
+                dag_id=self.task_instance.dag_id,
+                task_ids=[self.task_instance.task_id],
+                run_ids=[self.task_instance.run_id],
+                map_index=self.task_instance.map_index,
             )
-        return task_state
+            try:
+                return 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+            except Exception:
+                raise AirflowException(

Review Comment:
   ```suggestion
                   raise RuntimeError(
   ```
   
   Not sure whether we can deduplicate these parts by moving them to some base 
class or mix-in, but that can be done in a follow-up PR.



##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -266,6 +273,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
 
         @provide_session
         def get_task_instance(self, session: Session) -> TaskInstance:
+            """
+            Get the task instance for the current task.
+
+            :param session: Sqlalchemy session
+            """
+            if not self.task_instance:
+                raise AirflowException(f"TaskInstance not set on 
{self.__class__.__name__}!")

Review Comment:
   ```suggestion
                   raise RuntimeError(f"TaskInstance not set on 
{self.__class__.__name__}!")
   ```



##########
task-sdk/src/airflow/sdk/bases/operator.py:
##########
@@ -1411,6 +1415,15 @@ def _set_xcomargs_dependency(self, field: str, newvalue: 
Any) -> None:
             return
         XComArg.apply_upstream_relationship(self, newvalue)
 
+    def _validate_start_from_trigger_kwargs(self):
+        if self.start_from_trigger and self.start_trigger_args and 
self.start_trigger_args.trigger_kwargs:
+            for name, val in self.start_trigger_args.trigger_kwargs.items():
+                if callable(val):
+                    raise AirflowException(
+                        f"{self.__class__.__name__} with task_id 
'{self.task_id}' has a callable in trigger kwargs named "
+                        f"'{name}', which is not allowed when 
start_from_trigger is enabled."

Review Comment:
   I'm a bit confused here. Is it callables only, or anything that is not serializable?



##########
providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py:
##########
@@ -114,7 +122,7 @@ def get_task_instance(self, session: Session) -> 
TaskInstance:
             task_instance = query.one_or_none()
             if task_instance is None:
                 raise AirflowException(
-                    "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
+                    "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and 
map_index: %s is not found",

Review Comment:
   ```suggestion
                       "TaskInstance with dag_id: %s, task_id: %s, run_id: %s 
and map_index: %s is not found",
   ```



##########
airflow-core/src/airflow/executors/workloads.py:
##########
@@ -200,6 +200,9 @@ class RunTrigger(BaseModel):
 
     type: Literal["RunTrigger"] = Field(init=False, default="RunTrigger")
 
+    dag_data: dict | None = None
+    """Serialized DAG model in dict format so it can be deserialized in 
trigger subprocess."""

Review Comment:
   ```suggestion
       """Serialized Dag model in dict format so it can be deserialized in 
trigger subprocess."""
   ```



##########
airflow-core/src/airflow/models/dagbag.py:
##########
@@ -46,24 +46,29 @@ class DBDagBag:
     """
 
     def __init__(self, load_op_links: bool = True) -> None:
-        self._dags: dict[str, SerializedDAG] = {}  # dag_version_id to dag
+        self._dags: dict[str, SerializedDagModel] = {}  # dag_version_id to dag
         self.load_op_links = load_op_links
 
     def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None:
         serdag.load_op_links = self.load_op_links
         if dag := serdag.dag:
-            self._dags[serdag.dag_version_id] = dag
+            self._dags[serdag.dag_version_id] = serdag
         return dag
 
-    def _get_dag(self, version_id: str, session: Session) -> SerializedDAG | 
None:
-        if dag := self._dags.get(version_id):
-            return dag
-        dag_version = session.get(DagVersion, version_id, 
options=[joinedload(DagVersion.serialized_dag)])
-        if not dag_version:
-            return None
-        if not (serdag := dag_version.serialized_dag):
-            return None
-        return self._read_dag(serdag)
+    def get_dag_model(self, version_id: str, session: Session) -> 
SerializedDagModel | None:
+        if not (serdag := self._dags.get(version_id)):
+            dag_version = session.get(DagVersion, version_id, 
options=[joinedload(DagVersion.serialized_dag)])
+            if not dag_version:
+                return None
+            if not (serdag := dag_version.serialized_dag):
+                return None
+        self._read_dag(serdag)
+        return serdag
+
+    def get_dag(self, version_id: str, session: Session) -> SerializedDAG | 
None:

Review Comment:
   should we call it `get_ser_dag` for clarity instead?



##########
providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py:
##########
@@ -105,6 +105,14 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
 
         @provide_session
         def get_task_instance(self, session: Session) -> TaskInstance:
+            """
+            Get the task instance for the current task.
+
+            :param session: Sqlalchemy session
+            """
+            if not self.task_instance:
+                raise AirflowException(f"TaskInstance not set on 
{self.__class__.__name__}!")

Review Comment:
   Let's not use AirflowException
   
   ```suggestion
                   raise RuntimeError(f"TaskInstance not set on 
{self.__class__.__name__}!")
   ```



##########
airflow-core/src/airflow/models/taskinstance.py:
##########
@@ -1446,6 +1447,79 @@ def update_heartbeat(self):
                 .values(last_heartbeat_at=timezone.utcnow())
             )
 
+    def start_trigger_args(self) -> StartTriggerArgs | None:

Review Comment:
   ```suggestion
       @property
       def start_trigger_args(self) -> StartTriggerArgs | None:
   ```



##########
providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py:
##########
@@ -125,23 +133,27 @@ def get_task_instance(self, session: Session) -> 
TaskInstance:
     async def get_task_state(self):
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
-            dag_id=self.task_instance.dag_id,
-            task_ids=[self.task_instance.task_id],
-            run_ids=[self.task_instance.run_id],
-            map_index=self.task_instance.map_index,
-        )
-        try:
-            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
-        except Exception:
-            raise AirflowException(
-                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
-                self.task_instance.dag_id,
-                self.task_instance.task_id,
-                self.task_instance.run_id,
-                self.task_instance.map_index,
+        if not self.task_instance:
+            raise AirflowException(f"TaskInstance not set on 
{self.__class__.__name__}!")

Review Comment:
   ```suggestion
               raise RuntimeError(f"TaskInstance not set on 
{self.__class__.__name__}!")
   ```



##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -150,23 +153,27 @@ def get_task_instance(self, session: Session) -> 
TaskInstance:
     async def get_task_state(self):
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
-            dag_id=self.task_instance.dag_id,
-            task_ids=[self.task_instance.task_id],
-            run_ids=[self.task_instance.run_id],
-            map_index=self.task_instance.map_index,
-        )
-        try:
-            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
-        except Exception:
-            raise AirflowException(
-                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
-                self.task_instance.dag_id,
-                self.task_instance.task_id,
-                self.task_instance.run_id,
-                self.task_instance.map_index,
+        if not self.task_instance:
+            raise AirflowException(f"TaskInstance not set on 
{self.__class__.__name__}!")

Review Comment:
   ```suggestion
               raise RuntimeError(f"TaskInstance not set on 
{self.__class__.__name__}!")
   ```



##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -603,8 +606,49 @@ def update_triggers(self, requested_trigger_ids: set[int]):
         adds them to the deques so the subprocess can actually mutate the 
running
         trigger set.
         """
+        from airflow.models.dagbag import DBDagBag
+
+        dag_bag = DBDagBag()
         render_log_fname = log_filename_template_renderer()
 
+        @provide_session
+        def create_workload(trigger: Trigger, session: Session = NEW_SESSION) 
-> workloads.RunTrigger | None:
+            if trigger.task_instance:
+                if not new_trigger_orm.task_instance.dag_version_id:
+                    # This is to handle 2 to 3 upgrade where TI.dag_version_id 
can be none
+                    log.warning(
+                        "TaskInstance associated with Trigger has no 
associated Dag Version, skipping the trigger",
+                        ti_id=new_trigger_orm.task_instance.id,
+                    )
+                    return None

Review Comment:
   ```suggestion
               if trigger.task_instance and not 
new_trigger_orm.task_instance.dag_version_id:
                   # This is to handle 2 to 3 upgrade where TI.dag_version_id 
can be none
                   log.warning(
                       "TaskInstance associated with Trigger has no associated 
Dag Version, skipping the trigger",
                       ti_id=new_trigger_orm.task_instance.id,
                   )
                   return None
   ```



##########
providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py:
##########
@@ -125,23 +133,27 @@ def get_task_instance(self, session: Session) -> 
TaskInstance:
     async def get_task_state(self):
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
-            dag_id=self.task_instance.dag_id,
-            task_ids=[self.task_instance.task_id],
-            run_ids=[self.task_instance.run_id],
-            map_index=self.task_instance.map_index,
-        )
-        try:
-            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
-        except Exception:
-            raise AirflowException(
-                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
-                self.task_instance.dag_id,
-                self.task_instance.task_id,
-                self.task_instance.run_id,
-                self.task_instance.map_index,
+        if not self.task_instance:
+            raise AirflowException(f"TaskInstance not set on 
{self.__class__.__name__}!")
+
+        if not isinstance(self.task_instance, RuntimeTaskInstance):
+            task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
+                dag_id=self.task_instance.dag_id,
+                task_ids=[self.task_instance.task_id],
+                run_ids=[self.task_instance.run_id],
+                map_index=self.task_instance.map_index,
             )
-        return task_state
+            try:
+                return 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+            except Exception:
+                raise AirflowException(

Review Comment:
   ```suggestion
                   raise RuntimeError(
   ```



##########
task-sdk/src/airflow/sdk/bases/operator.py:
##########
@@ -1411,6 +1415,15 @@ def _set_xcomargs_dependency(self, field: str, newvalue: 
Any) -> None:
             return
         XComArg.apply_upstream_relationship(self, newvalue)
 
+    def _validate_start_from_trigger_kwargs(self):
+        if self.start_from_trigger and self.start_trigger_args and 
self.start_trigger_args.trigger_kwargs:
+            for name, val in self.start_trigger_args.trigger_kwargs.items():
+                if callable(val):
+                    raise AirflowException(

Review Comment:
   ```suggestion
                       raise ValueError(
   ```



##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -286,23 +301,27 @@ def get_task_instance(self, session: Session) -> 
TaskInstance:
     async def get_task_state(self):
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
-            dag_id=self.task_instance.dag_id,
-            task_ids=[self.task_instance.task_id],
-            run_ids=[self.task_instance.run_id],
-            map_index=self.task_instance.map_index,
-        )
-        try:
-            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
-        except Exception:
-            raise AirflowException(
-                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
-                self.task_instance.dag_id,
-                self.task_instance.task_id,
-                self.task_instance.run_id,
-                self.task_instance.map_index,
+        if not self.task_instance:
+            raise AirflowException(f"TaskInstance not set on 
{self.__class__.__name__}!")
+
+        if not isinstance(self.task_instance, RuntimeTaskInstance):
+            task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
+                dag_id=self.task_instance.dag_id,
+                task_ids=[self.task_instance.task_id],
+                run_ids=[self.task_instance.run_id],
+                map_index=self.task_instance.map_index,
             )
-        return task_state
+            try:
+                return 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+            except Exception:
+                raise AirflowException(

Review Comment:
   ```suggestion
                   raise RuntimeError(
   ```



##########
airflow-core/src/airflow/jobs/triggerer_job_runner.py:
##########
@@ -614,63 +658,46 @@ def update_triggers(self, requested_trigger_ids: 
set[int]):
         # Work out the two difference sets
         new_trigger_ids = requested_trigger_ids - known_trigger_ids
         cancel_trigger_ids = self.running_triggers - requested_trigger_ids
-        # Bulk-fetch new trigger records
-        new_triggers = Trigger.bulk_fetch(new_trigger_ids)
-        trigger_ids_with_non_task_associations = 
Trigger.fetch_trigger_ids_with_non_task_associations()
-        to_create: list[workloads.RunTrigger] = []
-        # Add in new triggers
-        for new_id in new_trigger_ids:
-            # Check it didn't vanish in the meantime
-            if new_id not in new_triggers:
-                log.warning("Trigger disappeared before we could start it", 
id=new_id)
-                continue
 
-            new_trigger_orm = new_triggers[new_id]
-
-            # If the trigger is not associated to a task, an asset, or a 
callback, this means the TaskInstance
-            # row was updated by either Trigger.submit_event or 
Trigger.submit_failure
-            # and can happen when a single trigger Job is being run on 
multiple TriggerRunners
-            # in a High-Availability setup.
-            if new_trigger_orm.task_instance is None and new_id not in 
trigger_ids_with_non_task_associations:
-                log.info(
-                    (
-                        "TaskInstance Trigger is None. It was likely updated 
by another trigger job. "
-                        "Skipping trigger instantiation."
-                    ),
-                    id=new_id,
-                )
-                continue
-
-            workload = workloads.RunTrigger(
-                classpath=new_trigger_orm.classpath,
-                id=new_id,
-                encrypted_kwargs=new_trigger_orm.encrypted_kwargs,
-                ti=None,
+        with create_session() as session:
+            # Bulk-fetch new trigger records
+            new_triggers = Trigger.bulk_fetch(new_trigger_ids, session=session)
+            trigger_ids_with_non_task_associations = 
Trigger.fetch_trigger_ids_with_non_task_associations(
+                session=session
             )
-            if new_trigger_orm.task_instance:
-                log_path = render_log_fname(ti=new_trigger_orm.task_instance)
-                if not new_trigger_orm.task_instance.dag_version_id:
-                    # This is to handle 2 to 3 upgrade where TI.dag_version_id 
can be none
-                    log.warning(
-                        "TaskInstance associated with Trigger has no 
associated Dag Version, skipping the trigger",
-                        ti_id=new_trigger_orm.task_instance.id,
+            to_create: list[workloads.RunTrigger] = []
+            # Add in new triggers
+            for new_id in new_trigger_ids:
+                # Check it didn't vanish in the meantime
+                if new_id not in new_triggers:
+                    log.warning("Trigger disappeared before we could start 
it", id=new_id)
+                    continue
+
+                new_trigger_orm = new_triggers[new_id]
+
+                # If the trigger is not associated to a task, an asset, or a 
callback, this means the TaskInstance
+                # row was updated by either Trigger.submit_event or 
Trigger.submit_failure
+                # and can happen when a single trigger Job is being run on 
multiple TriggerRunners
+                # in a High-Availability setup.
+                if (
+                    new_trigger_orm.task_instance is None
+                    and new_id not in trigger_ids_with_non_task_associations
+                ):
+                    log.info(
+                        (
+                            "TaskInstance Trigger is None. It was likely 
updated by another trigger job. "
+                            "Skipping trigger instantiation."
+                        ),
+                        id=new_id,
                     )
                     continue
-                ser_ti = workloads.TaskInstance.model_validate(
-                    new_trigger_orm.task_instance, from_attributes=True
-                )
-                # When producing logs from TIs, include the job id producing 
the logs to disambiguate it.
-                self.logger_cache[new_id] = TriggerLoggingFactory(
-                    log_path=f"{log_path}.trigger.{self.job.id}.log",
-                    ti=ser_ti,  # type: ignore
-                )
 
-                workload.ti = ser_ti
-                workload.timeout_after = 
new_trigger_orm.task_instance.trigger_timeout
+                workload = create_workload(new_trigger_orm, session=session)
 
-            to_create.append(workload)
+                if workload:
+                    to_create.append(workload)

Review Comment:
   ```suggestion
                   if (workload := create_workload(new_trigger_orm, 
session=session)):
                       to_create.append(workload)
   ```



##########
providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -286,23 +301,27 @@ def get_task_instance(self, session: Session) -> 
TaskInstance:
     async def get_task_state(self):
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
-            dag_id=self.task_instance.dag_id,
-            task_ids=[self.task_instance.task_id],
-            run_ids=[self.task_instance.run_id],
-            map_index=self.task_instance.map_index,
-        )
-        try:
-            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
-        except Exception:
-            raise AirflowException(
-                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
-                self.task_instance.dag_id,
-                self.task_instance.task_id,
-                self.task_instance.run_id,
-                self.task_instance.map_index,
+        if not self.task_instance:
+            raise AirflowException(f"TaskInstance not set on 
{self.__class__.__name__}!")

Review Comment:
   ```suggestion
               raise RuntimeError(f"TaskInstance not set on 
{self.__class__.__name__}!")
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to