baolsen commented on a change in pull request #6773: [AIRFLOW-6038] AWS
DataSync example_dags added
URL: https://github.com/apache/airflow/pull/6773#discussion_r359160672
##########
File path: tests/providers/amazon/aws/operators/test_datasync.py
##########
@@ -197,88 +213,107 @@ def test_create_task(self, mock_get_conn):
self.set_up_operator()
# Delete all tasks:
tasks = self.client.list_tasks()
- for task in tasks['Tasks']:
- self.client.delete_task(TaskArn=task['TaskArn'])
+ for task in tasks["Tasks"]:
+ self.client.delete_task(TaskArn=task["TaskArn"])
# Check how many tasks and locations we have
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks['Tasks']), 0)
+ self.assertEqual(len(tasks["Tasks"]), 0)
locations = self.client.list_locations()
- self.assertEqual(len(locations['Locations']), 2)
+ self.assertEqual(len(locations["Locations"]), 2)
# Execute the task
result = self.datasync.execute(None)
self.assertIsNotNone(result)
- task_arn = result
+ task_arn = result["TaskArn"]
# Assert 1 additional task and 0 additional locations
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks['Tasks']), 1)
+ self.assertEqual(len(tasks["Tasks"]), 1)
locations = self.client.list_locations()
- self.assertEqual(len(locations['Locations']), 2)
+ self.assertEqual(len(locations["Locations"]), 2)
# Check task metadata
task = self.client.describe_task(TaskArn=task_arn)
- self.assertEqual(task['Options'], CREATE_TASK_KWARGS['Options'])
+ self.assertEqual(task["Options"], CREATE_TASK_KWARGS["Options"])
- def test_create_task_even_if_one_exists(self, mock_get_conn):
+ def test_create_task_and_location(self, mock_get_conn):
# ### Set up mocks:
mock_get_conn.return_value = self.client
# ### Begin tests:
self.set_up_operator()
+ # Delete all tasks:
+ tasks = self.client.list_tasks()
+ for task in tasks["Tasks"]:
+ self.client.delete_task(TaskArn=task["TaskArn"])
+ # Delete all locations:
+ locations = self.client.list_locations()
+ for location in locations["Locations"]:
+ self.client.delete_location(LocationArn=location["LocationArn"])
# Check how many tasks and locations we have
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks['Tasks']), 1)
+ self.assertEqual(len(tasks["Tasks"]), 0)
locations = self.client.list_locations()
- self.assertEqual(len(locations['Locations']), 2)
+ self.assertEqual(len(locations["Locations"]), 0)
# Execute the task
result = self.datasync.execute(None)
self.assertIsNotNone(result)
- task_arn = result
- # Assert 1 additional task and 0 additional locations
+ # Assert 1 additional task and 2 additional locations
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks['Tasks']), 2)
+ self.assertEqual(len(tasks["Tasks"]), 1)
locations = self.client.list_locations()
- self.assertEqual(len(locations['Locations']), 2)
+ self.assertEqual(len(locations["Locations"]), 2)
- # Check task metadata
- task = self.client.describe_task(TaskArn=task_arn)
- self.assertEqual(task['Options'], CREATE_TASK_KWARGS['Options'])
-
- def test_create_task_and_location(self, mock_get_conn):
+ def test_dont_create_task(self, mock_get_conn):
# ### Set up mocks:
mock_get_conn.return_value = self.client
# ### Begin tests:
- self.set_up_operator()
- # Delete all tasks:
tasks = self.client.list_tasks()
- for task in tasks['Tasks']:
- self.client.delete_task(TaskArn=task['TaskArn'])
- # Delete all locations:
- locations = self.client.list_locations()
- for location in locations['Locations']:
- self.client.delete_location(LocationArn=location['LocationArn'])
+ tasks_before = len(tasks["Tasks"])
+
+ self.set_up_operator(task_arn=self.task_arn)
+ self.datasync.execute(None)
- # Check how many tasks and locations we have
tasks = self.client.list_tasks()
- self.assertEqual(len(tasks['Tasks']), 0)
- locations = self.client.list_locations()
- self.assertEqual(len(locations['Locations']), 0)
+ tasks_after = len(tasks["Tasks"])
+ self.assertEqual(tasks_before, tasks_after)
- # Execute the task
+ def create_task_many_locations(self, mock_get_conn):
+ # ### Set up mocks:
+ mock_get_conn.return_value = self.client
+ # ### Begin tests:
+
+ # Create duplicate source location to choose from
+ self.client.create_location_smb(
+ **MOCK_DATA["create_source_location_kwargs"]
+ )
+
+ self.set_up_operator(task_arn=self.task_arn)
+ with self.assertRaises(AirflowException):
+ self.datasync.execute(None)
+
+ self.set_up_operator(task_arn=self.task_arn,
choose_location_strategy='random')
+ self.datasync.execute(None)
+
Review comment:
In most of the tests, the get_hook() method is called more than once, except
when we only test init success/failure. So I've added asserts for "called"
vs. "not called".
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services