utkarsharma2 commented on a change in pull request #9907:
URL: https://github.com/apache/airflow/pull/9907#discussion_r478550030
##########
File path: tests/cli/commands/test_connection_command.py
##########
@@ -321,3 +330,87 @@ def test_cli_delete_invalid_connection(self):
# Check deletion attempt stdout
self.assertIn("\tDid not find a connection with `conn_id`=fake",
stdout)
+
+
+class TestCliImportConnections(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.parser = cli_parser.get_parser()
+ clear_db_connections()
+
+ @classmethod
+ def tearDownClass(cls):
+ clear_db_connections()
+
+ @parameterized.expand(
+ (
+ ({"CONN_ID": {'conn_type': 'mysql', 'host': 'host_1'}},
+ {"CONN_ID": {'conn_type': 'mysql', 'host': 'host_1'}}),
+ )
+ )
+ def test_connections_import(self, file_content, expected_connection_uris):
+ """Test connections_import command"""
+
+ with mock_local_file(json.dumps(file_content)):
+
connection_command.connections_import(self.parser.parse_args(['connections',
'import', "a.json"]))
+ with create_session() as session:
+ for conn_id in expected_connection_uris:
+ current_conn =
session.query(Connection).filter(Connection.conn_id == conn_id).first()
+ self.assertEqual(expected_connection_uris[conn_id],
+ {attr: getattr(current_conn, attr)
+ for attr in
expected_connection_uris[conn_id]})
+
+ @parameterized.expand(
+ (
+ (
+ {"CONN_ID2": [{'conn_type': 'mysql', 'host': 'host_1'},
+ {'conn_type': 'mysql', 'host': 'host_2'}]},
+ {"CONN_ID2": [{'conn_type': 'mysql', 'host': 'host_1'},
+ {'conn_type': 'mysql', 'host': 'host_2'}]}
+ ),
+ )
+ )
+ def test_connections_import_disposition_ignore(self, file_content,
expected_connection_uris):
+ """Test connections_import command with --conflict-disposition
ignore"""
+ with mock_local_file(json.dumps(file_content)):
+ connection_command.connections_import(self.parser.parse_args([
+ 'connections', 'import', 'a.json']))
+
+ connection_command.connections_import(self.parser.parse_args([
+ 'connections', 'import', 'a.json', '--conflict-disposition',
'ignore']))
+
+ conn_id = 'CONN_ID2'
+ with create_session() as session:
+ current_conn =
session.query(Connection).filter(Connection.conn_id == conn_id).first()
+ self.assertEqual(expected_connection_uris[conn_id][0],
+ {attr: getattr(current_conn, attr)
+ for attr in
expected_connection_uris[conn_id][0]})
+
+ @parameterized.expand(
+ (
+ (
+ {"CONN_ID3": [{'conn_type': 'mysql', 'host': 'host_1'},
+ {'conn_type': 'mysql', 'host': 'host_2'}]},
+ {"CONN_ID3": [{'conn_type': 'mysql', 'host': 'host_1'},
+ {'conn_type': 'mysql', 'host': 'host_2'}]}
+ ),
+ )
+ )
+ def test_connections_import_disposition_overwrite(self, file_content,
expected_connection_uris):
+ """Test connections_import command with --conflict-disposition
overwrite"""
+ with mock_local_file(json.dumps(file_content)):
+ connection_command.connections_import(self.parser.parse_args([
+ 'connections', 'import', 'a.json', '--conflict-disposition',
'overwrite']))
+ connection_command.connections_import(self.parser.parse_args([
+ 'connections', 'import', 'a.json', '--conflict-disposition',
'overwrite']))
+
+ with redirect_stdout(io.StringIO()) as stdout:
+ stdout = stdout.getvalue()
+ print(stdout)
+
+ conn_id = 'CONN_ID3'
+ with create_session() as session:
+ current_conn =
session.query(Connection).filter(Connection.conn_id == conn_id).first()
+ self.assertEqual(expected_connection_uris[conn_id][1],
Review comment:
Yup, makes sense; updated in the PR.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org