vincbeck commented on code in PR #41511:
URL: https://github.com/apache/airflow/pull/41511#discussion_r1718775008
##########
airflow/providers/amazon/aws/operators/athena.py:
##########
@@ -126,63 +130,97 @@ def execute(self, context: Context) -> str | None:
self.query_execution_context["Catalog"] = self.catalog
if self.output_location:
self.result_configuration["OutputLocation"] = self.output_location
- self.query_execution_id = self.hook.run_query(
- self.query,
- self.query_execution_context,
- self.result_configuration,
- self.client_request_token,
- self.workgroup,
- )
- AthenaQueryResultsLink.persist(
- context=context,
- operator=self,
- region_name=self.hook.conn_region_name,
- aws_partition=self.hook.conn_partition,
- query_execution_id=self.query_execution_id,
- )
- if self.deferrable:
- self.defer(
- trigger=AthenaTrigger(
- query_execution_id=self.query_execution_id,
- waiter_delay=self.sleep_time,
- waiter_max_attempts=self.max_polling_attempts,
- aws_conn_id=self.aws_conn_id,
- region_name=self.region_name,
- verify=self.verify,
- botocore_config=self.botocore_config,
- ),
- method_name="execute_complete",
+ if isinstance(self.query, str):
+ if self.split_statements:
+ query_list = self._split_sql_string(self.query)
+ else:
+ query_list = [self.query] if self.query.strip() else []
+ else:
+ query_list = self.query
+
+ query_list_len = len(query_list)
+
+ if not query_list_len:
+ raise AirflowException("No queries were found to execute.")
+
+ for query in query_list:
+ self.query_execution_id = self.hook.run_query(
+ query,
+ self.query_execution_context,
+ self.result_configuration,
+ self.client_request_token,
+ self.workgroup,
)
- # implicit else:
- query_status = self.hook.poll_query_status(
- self.query_execution_id,
- max_polling_attempts=self.max_polling_attempts,
- sleep_time=self.sleep_time,
- )
-
- if query_status in AthenaHook.FAILURE_STATES:
- error_message = self.hook.get_state_change_reason(self.query_execution_id)
- raise AirflowException(
- f"Final state of Athena job is {query_status}, query_execution_id is "
- f"{self.query_execution_id}. Error: {error_message}"
+ AthenaQueryResultsLink.persist(
+ context=context,
+ operator=self,
+ region_name=self.hook.conn_region_name,
+ aws_partition=self.hook.conn_partition,
+ query_execution_id=self.query_execution_id,
)
- elif not query_status or query_status in AthenaHook.INTERMEDIATE_STATES:
- raise AirflowException(
- f"Final state of Athena job is {query_status}. Max tries of poll status exceeded, "
- f"query_execution_id is {self.query_execution_id}."
+
+ if self.deferrable:
+ self.defer(
+ trigger=AthenaTrigger(
+ query_execution_id=self.query_execution_id,
+ waiter_delay=self.sleep_time,
+ waiter_max_attempts=self.max_polling_attempts,
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ verify=self.verify,
+ botocore_config=self.botocore_config,
+ ),
+ kwargs={"query_list": query_list},
+ method_name="execute_next_query",
+ )
+ # implicit else:
+ query_status = self.hook.poll_query_status(
+ self.query_execution_id,
+ max_polling_attempts=self.max_polling_attempts,
+ sleep_time=self.sleep_time,
)
+ if query_status in AthenaHook.FAILURE_STATES:
+ error_message = self.hook.get_state_change_reason(self.query_execution_id)
+ raise AirflowException(
+ f"Final state of Athena job is {query_status}, query_execution_id is "
+ f"{self.query_execution_id}. Error: {error_message}"
+ )
+ elif not query_status or query_status in AthenaHook.INTERMEDIATE_STATES:
+ raise AirflowException(
+ f"Final state of Athena job is {query_status}. Max tries of poll status exceeded, "
+ f"query_execution_id is {self.query_execution_id}."
+ )
+
return self.query_execution_id
- def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+ def execute_next_query(
Review Comment:
Instead of running the queries one by one, waiting for each query to finish
before starting the next, I would do it differently. Also, it is unusual to
have task running -> task deferred -> task running -> task deferred -> etc. I
am not saying this is wrong, but I am wondering whether it is an anti-pattern.
Why don't you start all the queries first and then wait for all of them to
complete? You'd have to update the trigger and the poll logic, but this is doable
##########
airflow/providers/amazon/aws/operators/athena.py:
##########
@@ -99,6 +101,7 @@ def __init__(
log_query: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
catalog: str = "AwsDataCatalog",
+ split_statements: bool = False,
Review Comment:
Please add it to the docstring
##########
airflow/providers/amazon/aws/operators/athena.py:
##########
@@ -87,7 +89,7 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
def __init__(
self,
*,
- query: str,
+ query: str | list[str],
Review Comment:
Please update the docstring
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]