baolsen commented on a change in pull request #6773: [AIRFLOW-6038] AWS DataSync example_dags added
URL: https://github.com/apache/airflow/pull/6773#discussion_r359157698
 
 

 ##########
 File path: tests/providers/amazon/aws/operators/test_datasync.py
 ##########
 @@ -197,88 +213,107 @@ def test_create_task(self, mock_get_conn):
         self.set_up_operator()
         # Delete all tasks:
         tasks = self.client.list_tasks()
-        for task in tasks['Tasks']:
-            self.client.delete_task(TaskArn=task['TaskArn'])
+        for task in tasks["Tasks"]:
+            self.client.delete_task(TaskArn=task["TaskArn"])
 
         # Check how many tasks and locations we have
         tasks = self.client.list_tasks()
-        self.assertEqual(len(tasks['Tasks']), 0)
+        self.assertEqual(len(tasks["Tasks"]), 0)
         locations = self.client.list_locations()
-        self.assertEqual(len(locations['Locations']), 2)
+        self.assertEqual(len(locations["Locations"]), 2)
 
         # Execute the task
         result = self.datasync.execute(None)
         self.assertIsNotNone(result)
-        task_arn = result
+        task_arn = result["TaskArn"]
 
         # Assert 1 additional task and 0 additional locations
         tasks = self.client.list_tasks()
-        self.assertEqual(len(tasks['Tasks']), 1)
+        self.assertEqual(len(tasks["Tasks"]), 1)
         locations = self.client.list_locations()
-        self.assertEqual(len(locations['Locations']), 2)
+        self.assertEqual(len(locations["Locations"]), 2)
 
         # Check task metadata
         task = self.client.describe_task(TaskArn=task_arn)
-        self.assertEqual(task['Options'], CREATE_TASK_KWARGS['Options'])
+        self.assertEqual(task["Options"], CREATE_TASK_KWARGS["Options"])
 
-    def test_create_task_even_if_one_exists(self, mock_get_conn):
+    def test_create_task_and_location(self, mock_get_conn):
         # ### Set up mocks:
         mock_get_conn.return_value = self.client
         # ### Begin tests:
 
         self.set_up_operator()
+        # Delete all tasks:
+        tasks = self.client.list_tasks()
+        for task in tasks["Tasks"]:
+            self.client.delete_task(TaskArn=task["TaskArn"])
+        # Delete all locations:
+        locations = self.client.list_locations()
+        for location in locations["Locations"]:
+            self.client.delete_location(LocationArn=location["LocationArn"])
 
         # Check how many tasks and locations we have
         tasks = self.client.list_tasks()
-        self.assertEqual(len(tasks['Tasks']), 1)
+        self.assertEqual(len(tasks["Tasks"]), 0)
         locations = self.client.list_locations()
-        self.assertEqual(len(locations['Locations']), 2)
+        self.assertEqual(len(locations["Locations"]), 0)
 
         # Execute the task
         result = self.datasync.execute(None)
         self.assertIsNotNone(result)
-        task_arn = result
 
-        # Assert 1 additional task and 0 additional locations
+        # Assert 1 additional task and 2 additional locations
         tasks = self.client.list_tasks()
-        self.assertEqual(len(tasks['Tasks']), 2)
+        self.assertEqual(len(tasks["Tasks"]), 1)
         locations = self.client.list_locations()
-        self.assertEqual(len(locations['Locations']), 2)
+        self.assertEqual(len(locations["Locations"]), 2)
 
-        # Check task metadata
-        task = self.client.describe_task(TaskArn=task_arn)
-        self.assertEqual(task['Options'], CREATE_TASK_KWARGS['Options'])
-
-    def test_create_task_and_location(self, mock_get_conn):
+    def test_dont_create_task(self, mock_get_conn):
         # ### Set up mocks:
         mock_get_conn.return_value = self.client
         # ### Begin tests:
 
-        self.set_up_operator()
-        # Delete all tasks:
         tasks = self.client.list_tasks()
-        for task in tasks['Tasks']:
-            self.client.delete_task(TaskArn=task['TaskArn'])
-        # Delete all locations:
-        locations = self.client.list_locations()
-        for location in locations['Locations']:
-            self.client.delete_location(LocationArn=location['LocationArn'])
+        tasks_before = len(tasks["Tasks"])
+
+        self.set_up_operator(task_arn=self.task_arn)
+        self.datasync.execute(None)
 
-        # Check how many tasks and locations we have
         tasks = self.client.list_tasks()
-        self.assertEqual(len(tasks['Tasks']), 0)
-        locations = self.client.list_locations()
-        self.assertEqual(len(locations['Locations']), 0)
+        tasks_after = len(tasks["Tasks"])
+        self.assertEqual(tasks_before, tasks_after)
 
-        # Execute the task
+    def create_task_many_locations(self, mock_get_conn):
+        # ### Set up mocks:
+        mock_get_conn.return_value = self.client
+        # ### Begin tests:
+
+        # Create duplicate source location to choose from
+        self.client.create_location_smb(
+            **MOCK_DATA["create_source_location_kwargs"]
+        )
+
+        self.set_up_operator(task_arn=self.task_arn)
+        with self.assertRaises(AirflowException):
+            self.datasync.execute(None)
+
+        self.set_up_operator(task_arn=self.task_arn, choose_location_strategy='random')
+        self.datasync.execute(None)
+
 
 Review comment:
   Good idea, I'll add something along those lines.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to