This is an automated email from the ASF dual-hosted git repository.

feluelle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new dea345b  Fix AwsGlueJobSensor to stop running after the Glue job finished (#9022)
dea345b is described below

commit dea345b05c2cd226e70f97a3934d7456aa1cc754
Author: Jubeen Lee <[email protected]>
AuthorDate: Tue Aug 18 01:41:50 2020 +0900

    Fix AwsGlueJobSensor to stop running after the Glue job finished (#9022)
    
    * Extract get_job_state and fix poke of AwsGlueJobSensor
    
    * Save hook and reuse in GlueJobSensor
    
    * Add descriptions for some functions
    
    * Fix tests according to changed function definition
    
    * Fix too long line
    
    * Add type hints and apply review
    
    * Fix type error
    
    Co-authored-by: JB Lee <[email protected]>
---
 airflow/providers/amazon/aws/hooks/glue.py        | 39 ++++++++++++++++-------
 airflow/providers/amazon/aws/operators/glue.py    |  1 +
 airflow/providers/amazon/aws/sensors/glue.py      |  5 ++-
 tests/providers/amazon/aws/hooks/test_glue.py     | 21 ++++++------
 tests/providers/amazon/aws/operators/test_glue.py |  8 ++++-
 tests/providers/amazon/aws/sensors/test_glue.py   | 12 +++----
 6 files changed, 55 insertions(+), 31 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/glue.py 
b/airflow/providers/amazon/aws/hooks/glue.py
index d7f538f..9db925d 100644
--- a/airflow/providers/amazon/aws/hooks/glue.py
+++ b/airflow/providers/amazon/aws/hooks/glue.py
@@ -108,29 +108,46 @@ class AwsGlueJobHook(AwsBaseHook):
                 JobName=job_name,
                 Arguments=script_arguments
             )
-            return self.job_completion(job_name, job_run['JobRunId'])
+            return job_run
         except Exception as general_error:
             self.log.error("Failed to run aws glue job, error: %s", 
general_error)
             raise
 
+    def get_job_state(self, job_name: str, run_id: str) -> str:
+        """
+        Get state of the Glue job. The job state can be
+        running, finished, failed, stopped or timeout.
+        :param job_name: unique job name per AWS account
+        :type job_name: str
+        :param run_id: The job-run ID of the predecessor job run
+        :type run_id: str
+        :return: State of the Glue job
+        """
+        glue_client = self.get_conn()
+        job_run = glue_client.get_job_run(
+            JobName=job_name,
+            RunId=run_id,
+            PredecessorsIncluded=True
+        )
+        job_run_state = job_run['JobRun']['JobRunState']
+        return job_run_state
+
     def job_completion(self, job_name: str, run_id: str) -> Dict[str, str]:
         """
+        Waits until Glue job with job_name completes or
+        fails and return final state if finished.
+        Raises AirflowException when the job failed
         :param job_name: unique job name per AWS account
         :type job_name: str
         :param run_id: The job-run ID of the predecessor job run
         :type run_id: str
-        :return: Status of the Job if succeeded or stopped
+        :return: Dict of JobRunState and JobRunId
         """
+        failed_states = ['FAILED', 'TIMEOUT']
+        finished_states = ['SUCCEEDED', 'STOPPED']
+
         while True:
-            glue_client = self.get_conn()
-            job_status = glue_client.get_job_run(
-                JobName=job_name,
-                RunId=run_id,
-                PredecessorsIncluded=True
-            )
-            job_run_state = job_status['JobRun']['JobRunState']
-            failed_states = ['FAILED', 'TIMEOUT']
-            finished_states = ['SUCCEEDED', 'STOPPED']
+            job_run_state = self.get_job_state(job_name, run_id)
             if job_run_state in finished_states:
                 self.log.info("Exiting Job %s Run State: %s", run_id, 
job_run_state)
                 return {'JobRunState': job_run_state, 'JobRunId': run_id}
diff --git a/airflow/providers/amazon/aws/operators/glue.py 
b/airflow/providers/amazon/aws/operators/glue.py
index 055a43d..a945f4e 100644
--- a/airflow/providers/amazon/aws/operators/glue.py
+++ b/airflow/providers/amazon/aws/operators/glue.py
@@ -108,6 +108,7 @@ class AwsGlueJobOperator(BaseOperator):
                                   iam_role_name=self.iam_role_name)
         self.log.info("Initializing AWS Glue Job: %s", self.job_name)
         glue_job_run = glue_job.initialize_job(self.script_args)
+        glue_job_run = glue_job.job_completion(self.job_name, 
glue_job_run['JobRunId'])
         self.log.info(
             "AWS Glue Job: %s status: %s. Run Id: %s",
             self.job_name, glue_job_run['JobRunState'], 
glue_job_run['JobRunId'])
diff --git a/airflow/providers/amazon/aws/sensors/glue.py 
b/airflow/providers/amazon/aws/sensors/glue.py
index 4525602..9539761 100644
--- a/airflow/providers/amazon/aws/sensors/glue.py
+++ b/airflow/providers/amazon/aws/sensors/glue.py
@@ -48,12 +48,11 @@ class AwsGlueJobSensor(BaseSensorOperator):
         self.errored_states = ['FAILED', 'STOPPED', 'TIMEOUT']
 
     def poke(self, context):
+        hook = AwsGlueJobHook(aws_conn_id=self.aws_conn_id)
         self.log.info(
             "Poking for job run status :"
             "for Glue Job %s and ID %s", self.job_name, self.run_id)
-        hook = AwsGlueJobHook(aws_conn_id=self.aws_conn_id)
-        job_state = hook.job_completion(job_name=self.job_name,
-                                        run_id=self.run_id)
+        job_state = hook.get_job_state(job_name=self.job_name, 
run_id=self.run_id)
         if job_state in self.success_states:
             self.log.info("Exiting Job %s Run State: %s", self.run_id, 
job_state)
             return True
diff --git a/tests/providers/amazon/aws/hooks/test_glue.py 
b/tests/providers/amazon/aws/hooks/test_glue.py
index 8815166..3871025 100644
--- a/tests/providers/amazon/aws/hooks/test_glue.py
+++ b/tests/providers/amazon/aws/hooks/test_glue.py
@@ -73,12 +73,12 @@ class TestGlueJobHook(unittest.TestCase):
             .get_or_create_glue_job()
         self.assertEqual(glue_job, mock_glue_job)
 
-    @mock.patch.object(AwsGlueJobHook, "job_completion")
+    @mock.patch.object(AwsGlueJobHook, "get_job_state")
     @mock.patch.object(AwsGlueJobHook, "get_or_create_glue_job")
     @mock.patch.object(AwsGlueJobHook, "get_conn")
     def test_initialize_job(self, mock_get_conn,
                             mock_get_or_create_glue_job,
-                            mock_completion):
+                            mock_get_job_state):
         some_data_path = "s3://glue-datasets/examples/medicare/SampleData.csv"
         some_script_arguments = {"--s3_input_data_path": some_data_path}
         some_script = "s3:/glue-examples/glue-scripts/sample_aws_glue_job.py"
@@ -87,14 +87,15 @@ class TestGlueJobHook(unittest.TestCase):
         mock_get_or_create_glue_job.Name = mock.Mock(Name='aws_test_glue_job')
         mock_get_conn.return_value.start_job_run()
 
-        mock_job_run_state = mock_completion.return_value
-        glue_job_run_state = AwsGlueJobHook(job_name='aws_test_glue_job',
-                                            desc='This is test case job from 
Airflow',
-                                            iam_role_name='my_test_role',
-                                            script_location=some_script,
-                                            s3_bucket=some_s3_bucket,
-                                            region_name=self.some_aws_region)\
-            .initialize_job(some_script_arguments)
+        mock_job_run_state = mock_get_job_state.return_value
+        glue_job_hook = AwsGlueJobHook(job_name='aws_test_glue_job',
+                                       desc='This is test case job from 
Airflow',
+                                       iam_role_name='my_test_role',
+                                       script_location=some_script,
+                                       s3_bucket=some_s3_bucket,
+                                       region_name=self.some_aws_region)
+        glue_job_run = glue_job_hook.initialize_job(some_script_arguments)
+        glue_job_run_state = 
glue_job_hook.get_job_state(glue_job_run['JobName'], glue_job_run['JobRunId'])
         self.assertEqual(glue_job_run_state, mock_job_run_state, msg='Mocks 
but be equal')
 
 
diff --git a/tests/providers/amazon/aws/operators/test_glue.py 
b/tests/providers/amazon/aws/operators/test_glue.py
index 93d9d70..83c2c0f 100644
--- a/tests/providers/amazon/aws/operators/test_glue.py
+++ b/tests/providers/amazon/aws/operators/test_glue.py
@@ -43,11 +43,17 @@ class TestAwsGlueJobOperator(unittest.TestCase):
                                        s3_bucket='some_bucket',
                                        iam_role_name='my_test_role')
 
+    @mock.patch.object(AwsGlueJobHook, 'get_job_state')
     @mock.patch.object(AwsGlueJobHook, 'initialize_job')
     @mock.patch.object(AwsGlueJobHook, "get_conn")
     @mock.patch.object(S3Hook, "load_file")
-    def test_execute_without_failure(self, mock_load_file, mock_get_conn, 
mock_initialize_job):
+    def test_execute_without_failure(self,
+                                     mock_load_file,
+                                     mock_get_conn,
+                                     mock_initialize_job,
+                                     mock_get_job_state):
         mock_initialize_job.return_value = {'JobRunState': 'RUNNING', 
'JobRunId': '11111'}
+        mock_get_job_state.return_value = 'SUCCEEDED'
         self.glue.execute(None)
 
         mock_initialize_job.assert_called_once_with({})
diff --git a/tests/providers/amazon/aws/sensors/test_glue.py 
b/tests/providers/amazon/aws/sensors/test_glue.py
index 70f407a..ec6f921 100644
--- a/tests/providers/amazon/aws/sensors/test_glue.py
+++ b/tests/providers/amazon/aws/sensors/test_glue.py
@@ -32,10 +32,10 @@ class TestAwsGlueJobSensor(unittest.TestCase):
         configuration.load_test_config()
 
     @mock.patch.object(AwsGlueJobHook, 'get_conn')
-    @mock.patch.object(AwsGlueJobHook, 'job_completion')
-    def test_poke(self, mock_job_completion, mock_conn):
+    @mock.patch.object(AwsGlueJobHook, 'get_job_state')
+    def test_poke(self, mock_get_job_state, mock_conn):
         mock_conn.return_value.get_job_run()
-        mock_job_completion.return_value = 'SUCCEEDED'
+        mock_get_job_state.return_value = 'SUCCEEDED'
         op = AwsGlueJobSensor(task_id='test_glue_job_sensor',
                               job_name='aws_test_glue_job',
                               run_id='5152fgsfsjhsh61661',
@@ -45,10 +45,10 @@ class TestAwsGlueJobSensor(unittest.TestCase):
         self.assertTrue(op.poke(None))
 
     @mock.patch.object(AwsGlueJobHook, 'get_conn')
-    @mock.patch.object(AwsGlueJobHook, 'job_completion')
-    def test_poke_false(self, mock_job_completion, mock_conn):
+    @mock.patch.object(AwsGlueJobHook, 'get_job_state')
+    def test_poke_false(self, mock_get_job_state, mock_conn):
         mock_conn.return_value.get_job_run()
-        mock_job_completion.return_value = 'RUNNING'
+        mock_get_job_state.return_value = 'RUNNING'
         op = AwsGlueJobSensor(task_id='test_glue_job_sensor',
                               job_name='aws_test_glue_job',
                               run_id='5152fgsfsjhsh61661',

Reply via email to