flolas commented on code in PR #41511:
URL: https://github.com/apache/airflow/pull/41511#discussion_r1718786865


##########
airflow/providers/amazon/aws/operators/athena.py:
##########
@@ -126,63 +130,97 @@ def execute(self, context: Context) -> str | None:
         self.query_execution_context["Catalog"] = self.catalog
         if self.output_location:
             self.result_configuration["OutputLocation"] = self.output_location
-        self.query_execution_id = self.hook.run_query(
-            self.query,
-            self.query_execution_context,
-            self.result_configuration,
-            self.client_request_token,
-            self.workgroup,
-        )
-        AthenaQueryResultsLink.persist(
-            context=context,
-            operator=self,
-            region_name=self.hook.conn_region_name,
-            aws_partition=self.hook.conn_partition,
-            query_execution_id=self.query_execution_id,
-        )
 
-        if self.deferrable:
-            self.defer(
-                trigger=AthenaTrigger(
-                    query_execution_id=self.query_execution_id,
-                    waiter_delay=self.sleep_time,
-                    waiter_max_attempts=self.max_polling_attempts,
-                    aws_conn_id=self.aws_conn_id,
-                    region_name=self.region_name,
-                    verify=self.verify,
-                    botocore_config=self.botocore_config,
-                ),
-                method_name="execute_complete",
+        if isinstance(self.query, str):
+            if self.split_statements:
+                query_list = self._split_sql_string(self.query)
+            else:
+                query_list = [self.query] if self.query.strip() else []
+        else:
+            query_list = self.query
+
+        query_list_len = len(query_list)
+
+        if not query_list_len:
+            raise AirflowException("No queries were found to execute.")
+
+        for query in query_list:
+            self.query_execution_id = self.hook.run_query(
+                query,
+                self.query_execution_context,
+                self.result_configuration,
+                self.client_request_token,
+                self.workgroup,
             )
-        # implicit else:
-        query_status = self.hook.poll_query_status(
-            self.query_execution_id,
-            max_polling_attempts=self.max_polling_attempts,
-            sleep_time=self.sleep_time,
-        )
-
-        if query_status in AthenaHook.FAILURE_STATES:
-        error_message = self.hook.get_state_change_reason(self.query_execution_id)
-            raise AirflowException(
-                f"Final state of Athena job is {query_status}, query_execution_id is "
-                f"{self.query_execution_id}. Error: {error_message}"
+            AthenaQueryResultsLink.persist(
+                context=context,
+                operator=self,
+                region_name=self.hook.conn_region_name,
+                aws_partition=self.hook.conn_partition,
+                query_execution_id=self.query_execution_id,
             )
-        elif not query_status or query_status in AthenaHook.INTERMEDIATE_STATES:
-            raise AirflowException(
-                f"Final state of Athena job is {query_status}. Max tries of poll status exceeded, "
-                f"query_execution_id is {self.query_execution_id}."
+
+            if self.deferrable:
+                self.defer(
+                    trigger=AthenaTrigger(
+                        query_execution_id=self.query_execution_id,
+                        waiter_delay=self.sleep_time,
+                        waiter_max_attempts=self.max_polling_attempts,
+                        aws_conn_id=self.aws_conn_id,
+                        region_name=self.region_name,
+                        verify=self.verify,
+                        botocore_config=self.botocore_config,
+                    ),
+                    kwargs={"query_list": query_list},
+                    method_name="execute_next_query",
+                )
+            # implicit else:
+            query_status = self.hook.poll_query_status(
+                self.query_execution_id,
+                max_polling_attempts=self.max_polling_attempts,
+                sleep_time=self.sleep_time,
             )
 
+            if query_status in AthenaHook.FAILURE_STATES:
+                error_message = self.hook.get_state_change_reason(self.query_execution_id)
+                raise AirflowException(
+                    f"Final state of Athena job is {query_status}, query_execution_id is "
+                    f"{self.query_execution_id}. Error: {error_message}"
+                )
+            elif not query_status or query_status in AthenaHook.INTERMEDIATE_STATES:
+                raise AirflowException(
+                    f"Final state of Athena job is {query_status}. Max tries of poll status exceeded, "
+                    f"query_execution_id is {self.query_execution_id}."
+                )
+
         return self.query_execution_id
 
-    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+    def execute_next_query(

Review Comment:
   >Why dont you run all queries and then you wait for all queries to be done? 
   This feature is for queries that need to be executed sequentially because 
they depend on each other. For example, one query might drop a temporary table, 
and the next one recreates it.
   
   For example:
   ```sql
   DROP TABLE temporal.transform_step;
   CREATE TABLE temporal.transform_step AS SELECT 1;
   ```
   Running them in parallel would cause conflicts. For parallel execution, 
using TaskFlow with `expand` would be a better approach.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to