ferruzzi commented on code in PR #63035:
URL: https://github.com/apache/airflow/pull/63035#discussion_r2933182418


##########
providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py:
##########
@@ -411,7 +467,7 @@ def process_queue(self, queue_url: str):
                 task_key = self.running_tasks[ser_task_key]
             except KeyError:

Review Comment:
   Agreed on all points.



##########
providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py:
##########
@@ -205,55 +213,97 @@ def sync(self):
             self.log.exception("An error occurred while syncing tasks")
 
     def queue_workload(self, workload: workloads.All, session: Session | None) 
-> None:
-        from airflow.executors import workloads
 
-        if not isinstance(workload, workloads.ExecuteTask):
-            raise RuntimeError(f"{type(self)} cannot handle workloads of type 
{type(workload)}")
-        ti = workload.ti
-        self.queued_tasks[ti.key] = workload
+        if isinstance(workload, workloads.ExecuteTask):
+            ti = workload.ti
+            self.queued_tasks[ti.key] = workload
+            return
+
+        if AIRFLOW_V_3_2_PLUS and isinstance(workload, 
workloads.ExecuteCallback):
+            self.queued_callbacks[workload.callback.id] = workload
+            return
+
+        raise RuntimeError(f"{type(self)} cannot handle workloads of type 
{type(workload)}")
+
+    def _process_workloads(self, workload_items: Sequence[workloads.All]) -> 
None:
+
+        for w in workload_items:
+            key: TaskInstanceKey | str
+            command: list[workloads.All]
+            queue: str | None
+            if isinstance(w, workloads.ExecuteTask):
+                command = [w]
+                key = w.ti.key
+                queue = w.ti.queue
+                executor_config = w.ti.executor_config or {}
+
+                del self.queued_tasks[key]
 
-    def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
-        from airflow.executors.workloads import ExecuteTask
+                self.execute_async(
+                    key=key,
+                    command=command,
+                    queue=queue,
+                    executor_config=executor_config,
+                )
+
+                self.running.add(key)
+                continue
 
-        for w in workloads:
-            if not isinstance(w, ExecuteTask):
-                raise RuntimeError(f"{type(self)} cannot handle workloads of 
type {type(w)}")
+            if AIRFLOW_V_3_2_PLUS and isinstance(w, workloads.ExecuteCallback):
+                command = [w]
+                key = w.callback.id
+                queue = None
 
-            command = [w]
-            key = w.ti.key
-            queue = w.ti.queue
-            executor_config = w.ti.executor_config or {}
+                if isinstance(w.callback.data, dict) and "queue" in 
w.callback.data:
+                    queue = w.callback.data["queue"]
 
-            del self.queued_tasks[key]
-            self.execute_async(key=key, command=command, queue=queue, 
executor_config=executor_config)  # type: ignore[arg-type]
-            self.running.add(key)
+                del self.queued_callbacks[key]
 
-    def execute_async(self, key: TaskInstanceKey, command: CommandType, 
queue=None, executor_config=None):
+                self.execute_async(
+                    key=key,
+                    command=command,
+                    queue=queue,
+                )
+
+                self.running.add(key)
+                continue
+
+            raise RuntimeError(f"{type(self)} cannot handle workloads of type 
{type(w)}")
+
+    def execute_async(
+        self,
+        key: TaskInstanceKey | str,
+        command: CommandType | Sequence[workloads.All],
+        queue=None,
+        executor_config=None,
+    ):
         """
-        Save the task to be executed in the next sync by inserting the 
commands into a queue.
+        Save the workload to be executed in the next sync by inserting the 
commands into a queue.
 
-        :param key: A unique task key (typically a tuple identifying the task 
instance).
+        :param key: Unique workload key. Task workloads use TaskInstanceKey, 
callback workloads use a string id.
         :param command: The shell command string to execute.
         :param executor_config:  (Unused) to keep the same signature as the 
base.
         :param queue: (Unused) to keep the same signature as the base.
         """
         if len(command) == 1:
-            from airflow.executors.workloads import ExecuteTask
-
-            if isinstance(command[0], ExecuteTask):
-                workload = command[0]
-                ser_input = workload.model_dump_json()
-                command = [
-                    "python",
-                    "-m",
-                    "airflow.sdk.execution_time.execute_workload",
-                    "--json-string",
-                    ser_input,
-                ]
+            if AIRFLOW_V_3_2_PLUS:
+                if not isinstance(command[0], (workloads.ExecuteTask, 
workloads.ExecuteCallback)):
+                    raise RuntimeError(f"{type(self)} cannot handle workloads 
of type {type(command[0])}")
             else:
-                raise RuntimeError(
-                    f"LambdaExecutor doesn't know how to handle workload of 
type: {type(command[0])}"
-                )
+                if not isinstance(command[0], workloads.ExecuteTask):
+                    raise RuntimeError(f"{type(self)} cannot handle workloads 
of type {type(command[0])}")
+
+            workload = command[0]

Review Comment:
   Yup, that's great. I didn't mean literally the first line, just up here where 
it can be reused throughout. 👍 



##########
providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py:
##########
@@ -202,90 +212,144 @@ def sync(self):
                     "AWS credentials are either missing or expired: 
%s.\nRetrying connection", error
                 )
         except Exception:
-            self.log.exception("An error occurred while syncing tasks")
+            self.log.exception("An error occurred while syncing workloads.")
 
     def queue_workload(self, workload: workloads.All, session: Session | None) 
-> None:
         from airflow.executors import workloads
 
-        if not isinstance(workload, workloads.ExecuteTask):
-            raise RuntimeError(f"{type(self)} cannot handle workloads of type 
{type(workload)}")
-        ti = workload.ti
-        self.queued_tasks[ti.key] = workload
+        if isinstance(workload, workloads.ExecuteTask):
+            ti = workload.ti
+            self.queued_tasks[ti.key] = workload
+            return
+
+        if AIRFLOW_V_3_2_PLUS and isinstance(workload, 
workloads.ExecuteCallback):
+            self.queued_callbacks[workload.callback.id] = workload
+            return
+
+        raise RuntimeError(f"{type(self)} cannot handle workloads of type 
{type(workload)}")
+
+    def _process_workloads(self, workload_items: Sequence[workloads.All]) -> 
None:
+        from airflow.executors import workloads
+

Review Comment:
   I agree with Niko. I know this was existing code, but that was a miss in an 
earlier review: `w` is a terrible variable name. Can you fix it while you are 
in here, please?



##########
providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py:
##########
@@ -49,29 +49,39 @@
     from airflow.executors import workloads
     from airflow.models.taskinstance import TaskInstance
 
+    if AIRFLOW_V_3_2_PLUS:
+        from airflow.executors.workloads.types import WorkloadKey as 
_WorkloadKey
+
+        WorkloadKey: TypeAlias = _WorkloadKey
+    else:
+        WorkloadKey: TypeAlias = TaskInstanceKey  # type: ignore[no-redef, 
misc]

Review Comment:
   For my own learning, what's the `misc` error we're ignoring here? `no-redef` 
makes sense.



##########
providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py:
##########
@@ -202,90 +212,144 @@ def sync(self):
                     "AWS credentials are either missing or expired: 
%s.\nRetrying connection", error
                 )
         except Exception:
-            self.log.exception("An error occurred while syncing tasks")
+            self.log.exception("An error occurred while syncing workloads.")
 
     def queue_workload(self, workload: workloads.All, session: Session | None) 
-> None:
         from airflow.executors import workloads
 
-        if not isinstance(workload, workloads.ExecuteTask):
-            raise RuntimeError(f"{type(self)} cannot handle workloads of type 
{type(workload)}")
-        ti = workload.ti
-        self.queued_tasks[ti.key] = workload
+        if isinstance(workload, workloads.ExecuteTask):
+            ti = workload.ti
+            self.queued_tasks[ti.key] = workload
+            return
+
+        if AIRFLOW_V_3_2_PLUS and isinstance(workload, 
workloads.ExecuteCallback):
+            self.queued_callbacks[workload.callback.id] = workload
+            return
+
+        raise RuntimeError(f"{type(self)} cannot handle workloads of type 
{type(workload)}")
+
+    def _process_workloads(self, workload_items: Sequence[workloads.All]) -> 
None:
+        from airflow.executors import workloads
+
+        for w in workload_items:
+            queue: str | None
+            key: WorkloadKey
+            command: CommandType
+            if isinstance(w, workloads.ExecuteTask):
+                command = [w]
+                key = w.ti.key
+                queue = w.ti.queue
+                executor_config = w.ti.executor_config or {}
+
+                del self.queued_tasks[key]
+
+                self.execute_async(
+                    key=key,
+                    command=command,
+                    queue=queue,
+                    executor_config=executor_config,
+                )
+
+                self.running.add(key)
+                continue
+
+            if AIRFLOW_V_3_2_PLUS and isinstance(w, workloads.ExecuteCallback):
+                command = [w]
+                key = w.callback.id
+                queue = None
 
-    def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
-        from airflow.executors.workloads import ExecuteTask
+                if isinstance(w.callback.data, dict) and "queue" in 
w.callback.data:
+                    queue = w.callback.data["queue"]
 
-        for w in workloads:
-            if not isinstance(w, ExecuteTask):
-                raise RuntimeError(f"{type(self)} cannot handle workloads of 
type {type(w)}")
+                del self.queued_callbacks[key]
 
-            command = [w]
-            key = w.ti.key
-            queue = w.ti.queue
-            executor_config = w.ti.executor_config or {}
+                self.execute_async(
+                    key=key,
+                    command=command,
+                    queue=queue,
+                )
 
-            del self.queued_tasks[key]
-            self.execute_async(key=key, command=command, queue=queue, 
executor_config=executor_config)  # type: ignore[arg-type]
-            self.running.add(key)
+                self.running.add(key)
+                continue
 
-    def execute_async(self, key: TaskInstanceKey, command: CommandType, 
queue=None, executor_config=None):
+            raise RuntimeError(f"{type(self)} cannot handle workloads of type 
{type(w)}")
+
+    def execute_async(
+        self,
+        key: WorkloadKey,
+        command: CommandType,
+        queue=None,
+        executor_config=None,
+    ):
         """
-        Save the task to be executed in the next sync by inserting the 
commands into a queue.
+        Save the workload to be executed in the next sync by inserting the 
commands into a queue.
 
-        :param key: A unique task key (typically a tuple identifying the task 
instance).
-        :param command: The shell command string to execute.
+        :param key: Unique workload key. Task workloads use TaskInstanceKey, 
callback workloads use a string id.
+        :param command: The workload command or serialized shell command to 
execute.
         :param executor_config:  (Unused) to keep the same signature as the 
base.
         :param queue: (Unused) to keep the same signature as the base.
         """
         if len(command) == 1:
-            from airflow.executors.workloads import ExecuteTask
-
-            if isinstance(command[0], ExecuteTask):
-                workload = command[0]
-                ser_input = workload.model_dump_json()
-                command = [
-                    "python",
-                    "-m",
-                    "airflow.sdk.execution_time.execute_workload",
-                    "--json-string",
-                    ser_input,
-                ]
+            from airflow.executors import workloads
+
+            workload = command[0]
+
+            if AIRFLOW_V_3_2_PLUS:
+                if not isinstance(workload, (workloads.ExecuteTask, 
workloads.ExecuteCallback)):
+                    raise RuntimeError(f"{type(self)} cannot handle workloads 
of type {type(workload)}")
             else:
-                raise RuntimeError(
-                    f"LambdaExecutor doesn't know how to handle workload of 
type: {type(command[0])}"
-                )
+                if not isinstance(workload, workloads.ExecuteTask):
+                    raise RuntimeError(f"{type(self)} cannot handle workloads 
of type {type(workload)}")
 
-        self.pending_tasks.append(
+            ser_input = workload.model_dump_json()
+
+            command = [
+                "python",
+                "-m",
+                "airflow.sdk.execution_time.execute_workload",
+                "--json-string",
+                ser_input,
+            ]
+
+        self.pending_workloads.append(
             LambdaQueuedTask(
                 key, command, queue if queue else "", executor_config or {}, 
1, timezone.utcnow()
             )
         )
 
-    def attempt_task_runs(self):
+    def attempt_workload_runs(self):
         """
-        Attempt to run tasks that are queued in the pending_tasks.
+        Attempt to run workloads that are queued in the pending_workloads.
 
-        Each task is submitted to AWS Lambda with a payload containing the 
task key and command.
-        The task key is used to track the task's state in Airflow.
+        Each workload is submitted to AWS Lambda with a payload containing the 
workload key and command.
+        The workload key is used to track the workload's state in Airflow.
         """
-        queue_len = len(self.pending_tasks)
+        queue_len = len(self.pending_workloads)
         for _ in range(queue_len):
-            task_to_run = self.pending_tasks.popleft()
-            task_key = task_to_run.key
-            cmd = task_to_run.command
-            attempt_number = task_to_run.attempt_number
+            workload_to_run = self.pending_workloads.popleft()
+            workload_key = workload_to_run.key
+            cmd = workload_to_run.command
+            attempt_number = workload_to_run.attempt_number
             failure_reasons = []
-            ser_task_key = json.dumps(task_key._asdict())
+
+            try:
+                ser_workload_key = json.dumps(workload_key._asdict())
+            except AttributeError:
+                # Callback workloads use string id.
+                ser_workload_key = workload_key
+
             payload = {
-                "task_key": ser_task_key,
+                "task_key": ser_workload_key,

Review Comment:
   Hrm. I guess eventually we'll want to update that payload key too, but 
this may not be the time for it; that feels like a bigger breaking change.



##########
providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py:
##########
@@ -313,59 +377,59 @@ def attempt_task_runs(self):
                 failure_reasons.append(str(e))
 
             if failure_reasons:
-                # Make sure the number of attempts does not exceed max invoke 
attempts
+                # Make sure the number of attempts does not exceed max invoke 
attempts.
                 if int(attempt_number) < int(self.max_invoke_attempts):
-                    task_to_run.attempt_number += 1
-                    task_to_run.next_attempt_time = timezone.utcnow() + 
calculate_next_attempt_delay(
+                    workload_to_run.attempt_number += 1
+                    workload_to_run.next_attempt_time = timezone.utcnow() + 
calculate_next_attempt_delay(
                         attempt_number
                     )
-                    self.pending_tasks.append(task_to_run)
+                    self.pending_workloads.append(workload_to_run)
                 else:
                     reasons_str = ", ".join(failure_reasons)
                     self.log.error(
                         "Lambda invoke %s has failed a maximum of %s times. 
Marking as failed. Reasons: %s",
-                        task_key,
+                        workload_key,
                         attempt_number,
                         reasons_str,
                     )
                     self.log_task_event(
                         event="lambda invoke failure",
-                        ti_key=task_key,
+                        ti_key=workload_key,

Review Comment:
   Non-blocking: looks like we missed this parameter name in the base executor; 
we'll have to get that later.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to