curiousjazz77 opened a new issue #8449: Create batch version of the 
AWSAthenaOperator
URL: https://github.com/apache/airflow/issues/8449
 
 
   **Description**
   
   Create a batch version of the AWSAthenaOperator that can accept multiple
   queries and execute them.
   
   **Use case / motivation**
   
   Currently, the AWSAthenaOperator is built to handle one query and poll for
   its success. If you have multiple queries to execute via Athena, you must
   move logic into the DAG and run the AWSAthenaOperator in a for loop over
   the queries. This is not best practice when a task generates a batch of
   queries to be submitted to Athena.
   
   This issue proposes creating an AWSAthenaBatchOperator that can execute a
   batch of queries. This would allow Airflow users to keep this logic in
   tasks instead of DAGs.
   
   A first take at such an operator:
   ```
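   # Import paths as on Airflow master at the time of writing (assumption);
   # on 1.10.x the hook lives in airflow.contrib.hooks.aws_athena_hook.
   from uuid import uuid4

   from airflow.models import BaseOperator
   from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook
   from airflow.utils.decorators import apply_defaults


   class AWSAthenaBatchOperator(BaseOperator):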
       """
       An operator that submits a batch of Presto queries to Athena for the
       same database.
       If ``do_xcom_push`` is True, the list of QueryExecutionIDs assigned to
       the queries will be pushed to an XCom when they successfully complete.
       :param queries: Presto queries to be run on Athena. (templated)
       :type queries: str delimited by ";\n"
       :param database: Database to select. (templated)
       :type database: str
       :param output_location: S3 path to write the query results into.
           (templated)
       :type output_location: str
       :param aws_conn_id: AWS connection to use
       :type aws_conn_id: str
       :param sleep_time: Time to wait between two consecutive calls to check
           query status on Athena
       :type sleep_time: int
       :param max_tries: Number of times to poll for query state before
           function exits
       :type max_tries: int
       """
   
       ui_color = '#44b5e2'
       template_fields = ('queries', 'database', 'output_location')
       template_ext = ('.sql', )
   
       @apply_defaults
       def __init__(  # pylint: disable=too-many-arguments
           self,
           queries,
           database,
           output_location,
           aws_conn_id="aws_default",
           workgroup="primary",
           query_execution_context=None,
           result_configuration=None,
           sleep_time=30,
           max_tries=None,
           *args,
           **kwargs
       ):
           super().__init__(*args, **kwargs)
           self.queries = queries
           self.database = database
           self.output_location = output_location
           self.aws_conn_id = aws_conn_id
           self.workgroup = workgroup
           self.query_execution_context = query_execution_context or {}
           self.result_configuration = result_configuration or {}
           self.sleep_time = sleep_time
           self.max_tries = max_tries
           self.query_execution_id = None
           self.hook = None
           self.query_execution_ids = []
   
       def get_hook(self):
           """Create and return an AWSAthenaHook."""
           return AWSAthenaHook(self.aws_conn_id, self.sleep_time)
   
       def execute(self, context):
           """
           Run Presto Query on Athena
           """
           self.hook = self.get_hook()
   
           self.query_execution_context['Database'] = self.database
           self.result_configuration['OutputLocation'] = self.output_location
   
           # Split the batch on ";\n" and skip empty fragments
           # (for example, the one left behind by a trailing delimiter).
           batch = [q.strip() for q in self.queries.split(";\n") if q.strip()]
   
           for query in batch:
               # New token for each query, for idempotency.
               self.client_request_token = str(uuid4())
               self.query_execution_id = self.hook.run_query(
                   query,
                   self.query_execution_context,
                   self.result_configuration,
                   self.client_request_token,
                   self.workgroup)
               query_status = self.hook.poll_query_status(
                   self.query_execution_id, self.max_tries)
   
               if query_status in AWSAthenaHook.FAILURE_STATES:
                   error_message = self.hook.get_state_change_reason(
                       self.query_execution_id)
                   raise Exception(
                       'Final state of Athena job is {}, query_execution_id is {}. '
                       'Error: {}'
                       .format(query_status, self.query_execution_id, error_message))
               elif not query_status or query_status in AWSAthenaHook.INTERMEDIATE_STATES:
                   raise Exception(
                       'Final state of Athena job is {}. '
                       'Max tries of poll status exceeded, query_execution_id is {}.'
                       .format(query_status, self.query_execution_id))
               self.query_execution_ids.append(self.query_execution_id)
   
           # Returned so that the IDs are pushed to XCom.
           return self.query_execution_ids
   ``` 
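
The batch semantics of the draft above can be tried out without Airflow or AWS. The following is only a minimal sketch under stated assumptions: `split_queries` and `run_batch` are hypothetical helper names, not part of Airflow, and `run_one` stands in for "submit one query to Athena and wait for it to finish". It shows how the `";\n"` delimiter behaves and that the batch stops at the first failing query.

```python
from typing import Callable, List


def split_queries(batch: str, delimiter: str = ";\n") -> List[str]:
    """Split a batch string into statements, dropping blank fragments.

    Hypothetical helper: mirrors the ";\n" convention of the draft operator.
    """
    return [part.strip() for part in batch.split(delimiter) if part.strip()]


def run_batch(queries: List[str], run_one: Callable[[str], str]) -> List[str]:
    """Run the queries in order and collect the ID returned for each one.

    ``run_one`` is a stand-in (assumption) for submitting a query and
    polling until it completes. If it raises, the exception propagates
    and the remaining queries are not submitted.
    """
    execution_ids = []
    for query in queries:
        execution_ids.append(run_one(query))
    return execution_ids


# Usage with a fake runner that just fabricates an ID per statement.
batch = "MSCK REPAIR TABLE events;\nSELECT count(*) FROM events;\n"
statements = split_queries(batch)
ids = run_batch(statements, lambda q: "fake-id-%d" % len(q))
print(statements)
print(ids)
```

One consequence of splitting on `";\n"` exactly: a final statement that ends in `;` with no trailing newline keeps its semicolon, and a `;\n` inside a string literal would split the statement in two. A real operator might accept a list of queries instead to avoid both cases.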
