This is an automated email from the ASF dual-hosted git repository.
kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new ddee0aa Simplify load connection in LocalFilesystemBackend (#10638)
ddee0aa is described below
commit ddee0aa4fbb9709c08ec7a39f985e30bbf6c5ffb
Author: Kamil Breguła <[email protected]>
AuthorDate: Sun Sep 6 20:56:03 2020 +0200
Simplify load connection in LocalFilesystemBackend (#10638)
---
airflow/secrets/local_filesystem.py | 38 ++++++---
tests/secrets/test_local_filesystem.py | 147 ++++++++++++++++++---------------
2 files changed, 107 insertions(+), 78 deletions(-)
diff --git a/airflow/secrets/local_filesystem.py
b/airflow/secrets/local_filesystem.py
index 29754d5..4a26119 100644
--- a/airflow/secrets/local_filesystem.py
+++ b/airflow/secrets/local_filesystem.py
@@ -21,6 +21,7 @@ Objects relating to retrieving connections and variables from
local file
import json
import logging
import os
+import warnings
from collections import defaultdict
from inspect import signature
from json import JSONDecodeError
@@ -235,33 +236,44 @@ def load_variables(file_path: str) -> Dict[str, str]:
return variables
-def load_connections(file_path: str):
+def load_connections(file_path) -> Dict[str, List[Any]]:
+ """
+    This function is deprecated. Please use
`airflow.secrets.local_filesystem.load_connections_dict`.
+ """
+ warnings.warn(
+ "This function is deprecated. Please use
`airflow.secrets.local_filesystem.load_connections_dict`.",
+ DeprecationWarning, stacklevel=2
+ )
+    return {k: [v] for k, v in load_connections_dict(file_path).items()}
+
+
+def load_connections_dict(file_path: str) -> Dict[str, Any]:
"""
Load connection from text file.
Both ``JSON`` and ``.env`` files are supported.
:return: A dictionary where the key contains a connection ID and the value
contains a list of connections.
- :rtype: Dict[str, List[airflow.models.connection.Connection]]
+ :rtype: Dict[str, airflow.models.connection.Connection]
"""
log.debug("Loading connection")
secrets: Dict[str, Any] = _parse_secret_file(file_path)
- connections_by_conn_id = defaultdict(list)
+ connection_by_conn_id = {}
for key, secret_values in list(secrets.items()):
if isinstance(secret_values, list):
+ if len(secret_values) > 1:
+ raise ConnectionNotUnique(f"Found multiple values for {key} in
{file_path}.")
+
for secret_value in secret_values:
- connections_by_conn_id[key].append(_create_connection(key,
secret_value))
+ connection_by_conn_id[key] = _create_connection(key,
secret_value)
else:
- connections_by_conn_id[key].append(_create_connection(key,
secret_values))
-
- if len(connections_by_conn_id[key]) > 1:
- raise ConnectionNotUnique(f"Found multiple values for {key} in
{file_path}")
+ connection_by_conn_id[key] = _create_connection(key, secret_values)
- num_conn = sum(map(len, connections_by_conn_id.values()))
+ num_conn = len(connection_by_conn_id)
log.debug("Loaded %d connections", num_conn)
- return connections_by_conn_id
+ return connection_by_conn_id
class LocalFilesystemBackend(BaseSecretsBackend, LoggingMixin):
@@ -298,10 +310,12 @@ class LocalFilesystemBackend(BaseSecretsBackend,
LoggingMixin):
self.log.debug("The file for connection is not specified.
Skipping")
# The user may not specify any file.
return {}
- return load_connections(self.connections_file)
+ return load_connections_dict(self.connections_file)
def get_connections(self, conn_id: str) -> List[Any]:
- return self._local_connections.get(conn_id) or []
+ if conn_id in self._local_connections:
+ return [self._local_connections[conn_id]]
+ return []
def get_variable(self, key: str) -> Optional[str]:
return self._local_variables.get(key)
diff --git a/tests/secrets/test_local_filesystem.py
b/tests/secrets/test_local_filesystem.py
index 97f6d42..61849b3 100644
--- a/tests/secrets/test_local_filesystem.py
+++ b/tests/secrets/test_local_filesystem.py
@@ -122,27 +122,27 @@ class TestLoadVariables(unittest.TestCase):
class TestLoadConnection(unittest.TestCase):
@parameterized.expand(
(
- ("CONN_ID=mysql://host_1/", {"CONN_ID": ["mysql://host_1"]}),
+ ("CONN_ID=mysql://host_1/", {"CONN_ID": "mysql://host_1"}),
(
"CONN_ID1=mysql://host_1/\nCONN_ID2=mysql://host_2/",
- {"CONN_ID1": ["mysql://host_1"], "CONN_ID2":
["mysql://host_2"]},
+ {"CONN_ID1": "mysql://host_1", "CONN_ID2": "mysql://host_2"},
),
(
"CONN_ID1=mysql://host_1/\n # AAAA\nCONN_ID2=mysql://host_2/",
- {"CONN_ID1": ["mysql://host_1"], "CONN_ID2":
["mysql://host_2"]},
+ {"CONN_ID1": "mysql://host_1", "CONN_ID2": "mysql://host_2"},
),
(
"\n\n\n\nCONN_ID1=mysql://host_1/\n\n\n\n\nCONN_ID2=mysql://host_2/\n\n\n",
- {"CONN_ID1": ["mysql://host_1"], "CONN_ID2":
["mysql://host_2"]},
+ {"CONN_ID1": "mysql://host_1", "CONN_ID2": "mysql://host_2"},
),
)
)
def test_env_file_should_load_connection(self, file_content,
expected_connection_uris):
with mock_local_file(file_content):
- connections_by_conn_id = local_filesystem.load_connections("a.env")
+ connection_by_conn_id =
local_filesystem.load_connections_dict("a.env")
connection_uris_by_conn_id = {
- conn_id: [connection.get_uri() for connection in connections]
- for conn_id, connections in connections_by_conn_id.items()
+ conn_id: connection.get_uri()
+ for conn_id, connection in connection_by_conn_id.items()
}
self.assertEqual(expected_connection_uris,
connection_uris_by_conn_id)
@@ -156,22 +156,22 @@ class TestLoadConnection(unittest.TestCase):
def test_env_file_invalid_format(self, content, expected_message):
with mock_local_file(content):
with self.assertRaisesRegex(AirflowFileParseException,
re.escape(expected_message)):
- local_filesystem.load_connections("a.env")
+ local_filesystem.load_connections_dict("a.env")
@parameterized.expand(
(
- ({"CONN_ID": "mysql://host_1"}, {"CONN_ID": ["mysql://host_1"]}),
- ({"CONN_ID": ["mysql://host_1"]}, {"CONN_ID": ["mysql://host_1"]}),
- ({"CONN_ID": {"uri": "mysql://host_1"}}, {"CONN_ID":
["mysql://host_1"]}),
- ({"CONN_ID": [{"uri": "mysql://host_1"}]}, {"CONN_ID":
["mysql://host_1"]}),
+ ({"CONN_ID": "mysql://host_1"}, {"CONN_ID": "mysql://host_1"}),
+ ({"CONN_ID": ["mysql://host_1"]}, {"CONN_ID": "mysql://host_1"}),
+ ({"CONN_ID": {"uri": "mysql://host_1"}}, {"CONN_ID":
"mysql://host_1"}),
+ ({"CONN_ID": [{"uri": "mysql://host_1"}]}, {"CONN_ID":
"mysql://host_1"}),
)
)
def test_json_file_should_load_connection(self, file_content,
expected_connection_uris):
with mock_local_file(json.dumps(file_content)):
- connections_by_conn_id =
local_filesystem.load_connections("a.json")
+ connections_by_conn_id =
local_filesystem.load_connections_dict("a.json")
connection_uris_by_conn_id = {
- conn_id: [connection.get_uri() for connection in connections]
- for conn_id, connections in connections_by_conn_id.items()
+ conn_id: connection.get_uri()
+ for conn_id, connection in connections_by_conn_id.items()
}
self.assertEqual(expected_connection_uris,
connection_uris_by_conn_id)
@@ -181,15 +181,16 @@ class TestLoadConnection(unittest.TestCase):
({"CONN_ID": None}, "Unexpected value type: <class 'NoneType'>."),
({"CONN_ID": 1}, "Unexpected value type: <class 'int'>."),
({"CONN_ID": [2]}, "Unexpected value type: <class 'int'>."),
- ({"CONN_ID": ["mysql://host_1", None]}, "Unexpected value type:
<class 'NoneType'>."),
+ ({"CONN_ID": [None]}, "Unexpected value type: <class
'NoneType'>."),
({"CONN_ID": {"AAA": "mysql://host_1"}}, "The object have illegal
keys: AAA."),
({"CONN_ID": {"conn_id": "BBBB"}}, "Mismatch conn_id."),
+ ({"CONN_ID": ["mysql://", "mysql://"]}, "Found multiple values for
CONN_ID in a.json."),
)
)
def test_env_file_invalid_input(self, file_content,
expected_connection_uris):
with mock_local_file(json.dumps(file_content)):
with self.assertRaisesRegex(AirflowException,
re.escape(expected_connection_uris)):
- local_filesystem.load_connections("a.json")
+ local_filesystem.load_connections_dict("a.json")
@mock.patch("airflow.secrets.local_filesystem.os.path.exists",
return_value=False)
def test_missing_file(self, mock_exists):
@@ -197,11 +198,11 @@ class TestLoadConnection(unittest.TestCase):
AirflowException,
re.escape("File a.json was not found. Check the configuration of
your Secrets backend."),
):
- local_filesystem.load_connections("a.json")
+ local_filesystem.load_connections_dict("a.json")
@parameterized.expand(
(
- ("""CONN_A: 'mysql://host_a'""", {"CONN_A": ["mysql://host_a"]}),
+ ("""CONN_A: 'mysql://host_a'""", {"CONN_A": "mysql://host_a"}),
("""
conn_a: mysql://hosta
conn_b:
@@ -215,66 +216,80 @@ class TestLoadConnection(unittest.TestCase):
extra__google_cloud_platform__keyfile_dict:
a: b
extra__google_cloud_platform__keyfile_path: asaa""",
- {"conn_a": ["mysql://hosta"],
- "conn_b":
[''.join("""scheme://Login:None@host:1234/lschema?
+ {"conn_a": "mysql://hosta",
+ "conn_b": ''.join("""scheme://Login:None@host:1234/lschema?
extra__google_cloud_platform__keyfile_dict=%7B%27a%27%3A+%27b%27%7D
-
&extra__google_cloud_platform__keyfile_path=asaa""".split())]}),
+
&extra__google_cloud_platform__keyfile_path=asaa""".split())}),
)
)
def test_yaml_file_should_load_connection(self, file_content,
expected_connection_uris):
with mock_local_file(file_content):
- connections_by_conn_id =
local_filesystem.load_connections("a.yaml")
+ connections_by_conn_id =
local_filesystem.load_connections_dict("a.yaml")
connection_uris_by_conn_id = {
- conn_id: [connection.get_uri() for connection in connections]
- for conn_id, connections in connections_by_conn_id.items()
+ conn_id: connection.get_uri()
+ for conn_id, connection in connections_by_conn_id.items()
}
self.assertEqual(expected_connection_uris,
connection_uris_by_conn_id)
@parameterized.expand(
(
- ("""conn_c:
- conn_type: scheme
- host: host
- schema: lschema
- login: Login
- password: None
- port: 1234
- extra_dejson:
- aws_conn_id: bbb
- region_name: ccc
- """, {"conn_c": [{"aws_conn_id": "bbb", "region_name":
"ccc"}]}),
- ("""conn_d:
- conn_type: scheme
- host: host
- schema: lschema
- login: Login
- password: None
- port: 1234
- extra_dejson:
- extra__google_cloud_platform__keyfile_dict:
- a: b
- extra__google_cloud_platform__key_path: xxx
- """, {"conn_d":
[{"extra__google_cloud_platform__keyfile_dict": {"a": "b"},
- "extra__google_cloud_platform__key_path":
"xxx"}]}),
- ("""conn_d:
- conn_type: scheme
- host: host
- schema: lschema
- login: Login
- password: None
- port: 1234
- extra: '{\"extra__google_cloud_platform__keyfile_dict\":
{\"a\": \"b\"}}'""", {"conn_d": [
- {"extra__google_cloud_platform__keyfile_dict": {"a": "b"}}]})
-
+ (
+ """
+ conn_c:
+ conn_type: scheme
+ host: host
+ schema: lschema
+ login: Login
+ password: None
+ port: 1234
+ extra_dejson:
+ aws_conn_id: bbb
+ region_name: ccc
+ """,
+ {"conn_c": {"aws_conn_id": "bbb", "region_name": "ccc"}},
+ ),
+ (
+ """
+ conn_d:
+ conn_type: scheme
+ host: host
+ schema: lschema
+ login: Login
+ password: None
+ port: 1234
+ extra_dejson:
+ extra__google_cloud_platform__keyfile_dict:
+ a: b
+ extra__google_cloud_platform__key_path: xxx
+ """,
+ {
+ "conn_d": {
+ "extra__google_cloud_platform__keyfile_dict": {"a":
"b"},
+ "extra__google_cloud_platform__key_path": "xxx",
+ }
+ },
+ ),
+ (
+ """
+ conn_d:
+ conn_type: scheme
+ host: host
+ schema: lschema
+ login: Login
+ password: None
+ port: 1234
+ extra: '{\"extra__google_cloud_platform__keyfile_dict\":
{\"a\": \"b\"}}'
+ """,
+ {"conn_d": {"extra__google_cloud_platform__keyfile_dict":
{"a": "b"}}},
+ ),
)
)
def test_yaml_file_should_load_connection_extras(self, file_content,
expected_extras):
with mock_local_file(file_content):
- connections_by_conn_id =
local_filesystem.load_connections("a.yaml")
+ connections_by_conn_id =
local_filesystem.load_connections_dict("a.yaml")
connection_uris_by_conn_id = {
- conn_id: [connection.extra_dejson for connection in
connections]
- for conn_id, connections in connections_by_conn_id.items()
+ conn_id: connection.extra_dejson for conn_id, connection in
connections_by_conn_id.items()
}
self.assertEqual(expected_extras, connection_uris_by_conn_id)
@@ -298,7 +313,7 @@ class TestLoadConnection(unittest.TestCase):
def test_yaml_invalid_extra(self, file_content, expected_message):
with mock_local_file(file_content):
with self.assertRaisesRegex(AirflowException,
re.escape(expected_message)):
- local_filesystem.load_connections("a.yaml")
+ local_filesystem.load_connections_dict("a.yaml")
@parameterized.expand(
(
@@ -308,7 +323,7 @@ class TestLoadConnection(unittest.TestCase):
def test_ensure_unique_connection_env(self, file_content):
with mock_local_file(file_content):
with self.assertRaises(ConnectionNotUnique):
- local_filesystem.load_connections("a.env")
+ local_filesystem.load_connections_dict("a.env")
@parameterized.expand(
(
@@ -323,7 +338,7 @@ class TestLoadConnection(unittest.TestCase):
def test_ensure_unique_connection_json(self, file_content):
with mock_local_file(json.dumps(file_content)):
with self.assertRaises(ConnectionNotUnique):
- local_filesystem.load_connections("a.json")
+ local_filesystem.load_connections_dict("a.json")
@parameterized.expand(
(
@@ -336,7 +351,7 @@ class TestLoadConnection(unittest.TestCase):
def test_ensure_unique_connection_yaml(self, file_content):
with mock_local_file(file_content):
with self.assertRaises(ConnectionNotUnique):
- local_filesystem.load_connections("a.yaml")
+ local_filesystem.load_connections_dict("a.yaml")
class TestLocalFileBackend(unittest.TestCase):