amoghrajesh commented on code in PR #48513:
URL: https://github.com/apache/airflow/pull/48513#discussion_r2028240448


##########
task-sdk/src/airflow/sdk/execution_time/execute_workload.py:
##########
@@ -74,16 +71,44 @@ def main():
     parser = argparse.ArgumentParser(
         description="Execute a workload in a Containerised executor using the 
task SDK."
     )
-    parser.add_argument(
-        "input_file", help="Path to the input JSON file containing the 
execution workload payload."
+
+    # Create a mutually exclusive group to ensure that only one of the flags 
is set
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument(
+        "--json-path",
+        help="Path to the input JSON file containing the execution workload 
payload.",
+        type=str,
+    )
+    group.add_argument(
+        "--json-string",
+        help="The JSON string itself containing the execution workload 
payload.",
+        type=str,

Review Comment:
   Yeah sounds fair



##########
providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py:
##########
@@ -462,6 +497,24 @@ def execute_async(self, key: TaskInstanceKey, command: 
CommandType, queue=None,
         """Save the task to be executed in the next sync by inserting the 
commands into a queue."""
         if executor_config and ("name" in executor_config or "command" in 
executor_config):
             raise ValueError('Executor Config should never override "name" or 
"command"')
+        if len(command) == 1:
+            from airflow.executors.workloads import ExecuteTask
+
+            if isinstance(command[0], ExecuteTask):
+                workload = command[0]
+                ser_input = workload.model_dump_json()
+                command = [
+                    "python",
+                    "-m",
+                    "airflow.sdk.execution_time.execute_workload",
+                    "--json-string",
+                    ser_input,
+                ]

Review Comment:
   I think the serialized workload JSON, passed as a command-line argument here, will show up in the logs (and in `ps` output). Are you OK with that?



##########
providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py:
##########
@@ -61,7 +61,7 @@ class EcsQueuedTask:
 
     key: TaskInstanceKey
     command: CommandType
-    queue: str
+    queue: str | None

Review Comment:
   Can there be a None queue?



##########
task-sdk/src/airflow/sdk/execution_time/execute_workload.py:
##########
@@ -35,9 +35,7 @@
 log = structlog.get_logger(logger_name=__name__)
 
 
-def execute_workload(input: str) -> None:
-    from pydantic import TypeAdapter
-
+def execute_workload(workload) -> None:

Review Comment:
   Can we add a type annotation for the `workload` parameter (e.g. `workloads.All`)?



##########
task-sdk/src/airflow/sdk/execution_time/execute_workload.py:
##########
@@ -74,16 +71,44 @@ def main():
     parser = argparse.ArgumentParser(
         description="Execute a workload in a Containerised executor using the 
task SDK."
     )
-    parser.add_argument(
-        "input_file", help="Path to the input JSON file containing the 
execution workload payload."
+
+    # Create a mutually exclusive group to ensure that only one of the flags 
is set
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument(
+        "--json-path",
+        help="Path to the input JSON file containing the execution workload 
payload.",
+        type=str,
+    )
+    group.add_argument(
+        "--json-string",
+        help="The JSON string itself containing the execution workload 
payload.",
+        type=str,
     )
 
     args = parser.parse_args()
 
-    with open(args.input_file) as file:
-        input_data = file.read()
+    from pydantic import TypeAdapter
 
-    execute_workload(input_data)
+    from airflow.executors import workloads
+
+    decoder = TypeAdapter[workloads.All](workloads.All)
+    if args.json_path:
+        try:
+            with open(args.json_path) as file:
+                input_data = file.read()
+                workload = decoder.validate_json(input_data)
+        except Exception as e:
+            log.error("Failed to read file", error=str(e))
+            sys.exit(1)
+
+    elif args.json_string:
+        try:
+            workload = decoder.validate_json(args.json_string)
+        except Exception as e:
+            log.error("Failed to parse input JSON string", error=str(e))
+            sys.exit(1)

Review Comment:
   Should we add an explicit check (an `else` branch) for the case where neither flag is set, so `workload` is never unbound?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to