This is an automated email from the ASF dual-hosted git repository.

kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new ddee0aa  Simplify load connection in LocalFilesystemBackend (#10638)
ddee0aa is described below

commit ddee0aa4fbb9709c08ec7a39f985e30bbf6c5ffb
Author: Kamil Breguła <[email protected]>
AuthorDate: Sun Sep 6 20:56:03 2020 +0200

    Simplify load connection in LocalFilesystemBackend (#10638)
---
 airflow/secrets/local_filesystem.py    |  38 ++++++---
 tests/secrets/test_local_filesystem.py | 147 ++++++++++++++++++---------------
 2 files changed, 107 insertions(+), 78 deletions(-)

diff --git a/airflow/secrets/local_filesystem.py 
b/airflow/secrets/local_filesystem.py
index 29754d5..4a26119 100644
--- a/airflow/secrets/local_filesystem.py
+++ b/airflow/secrets/local_filesystem.py
@@ -21,6 +21,7 @@ Objects relating to retrieving connections and variables from 
local file
 import json
 import logging
 import os
+import warnings
 from collections import defaultdict
 from inspect import signature
 from json import JSONDecodeError
@@ -235,33 +236,44 @@ def load_variables(file_path: str) -> Dict[str, str]:
     return variables
 
 
-def load_connections(file_path: str):
+def load_connections(file_path: str) -> Dict[str, List[Any]]:
+    """
+    This function is deprecated. Please use `airflow.secrets.local_filesystem.load_connections_dict`.
+    """
+    warnings.warn(
+        "This function is deprecated. Please use `airflow.secrets.local_filesystem.load_connections_dict`.",
+        DeprecationWarning, stacklevel=2
+    )
+    return {k: [v] for k, v in load_connections_dict(file_path).items()}
+
+
+def load_connections_dict(file_path: str) -> Dict[str, Any]:
     """
     Load connection from text file.
 
-    Both ``JSON`` and ``.env`` files are supported.
+    ``JSON``, ``YAML`` and ``.env`` files are supported.
 
-    :return: A dictionary where the key contains a connection ID and the value contains a list of connections.
-    :rtype: Dict[str, List[airflow.models.connection.Connection]]
+    :return: A dictionary where the key contains a connection ID and the value contains the connection.
+    :rtype: Dict[str, airflow.models.connection.Connection]
     """
     log.debug("Loading connection")
 
     secrets: Dict[str, Any] = _parse_secret_file(file_path)
-    connections_by_conn_id = defaultdict(list)
+    connection_by_conn_id: Dict[str, Any] = {}
     for key, secret_values in list(secrets.items()):
         if isinstance(secret_values, list):
+            if len(secret_values) > 1:
+                raise ConnectionNotUnique(f"Found multiple values for {key} in {file_path}.")
+
             for secret_value in secret_values:
-                connections_by_conn_id[key].append(_create_connection(key, secret_value))
+                connection_by_conn_id[key] = _create_connection(key, secret_value)
         else:
-            connections_by_conn_id[key].append(_create_connection(key, secret_values))
-
-        if len(connections_by_conn_id[key]) > 1:
-            raise ConnectionNotUnique(f"Found multiple values for {key} in {file_path}")
+            connection_by_conn_id[key] = _create_connection(key, secret_values)
 
-    num_conn = sum(map(len, connections_by_conn_id.values()))
+    num_conn = len(connection_by_conn_id)
     log.debug("Loaded %d connections", num_conn)
 
-    return connections_by_conn_id
+    return connection_by_conn_id
 
 
 class LocalFilesystemBackend(BaseSecretsBackend, LoggingMixin):
@@ -298,10 +310,12 @@ class LocalFilesystemBackend(BaseSecretsBackend, 
LoggingMixin):
             self.log.debug("The file for connection is not specified. 
Skipping")
             # The user may not specify any file.
             return {}
-        return load_connections(self.connections_file)
+        return load_connections_dict(self.connections_file)
 
     def get_connections(self, conn_id: str) -> List[Any]:
-        return self._local_connections.get(conn_id) or []
+        if conn_id in self._local_connections:
+            return [self._local_connections[conn_id]]
+        return []
 
     def get_variable(self, key: str) -> Optional[str]:
         return self._local_variables.get(key)
diff --git a/tests/secrets/test_local_filesystem.py 
b/tests/secrets/test_local_filesystem.py
index 97f6d42..61849b3 100644
--- a/tests/secrets/test_local_filesystem.py
+++ b/tests/secrets/test_local_filesystem.py
@@ -122,27 +122,27 @@ class TestLoadVariables(unittest.TestCase):
 class TestLoadConnection(unittest.TestCase):
     @parameterized.expand(
         (
-            ("CONN_ID=mysql://host_1/", {"CONN_ID": ["mysql://host_1"]}),
+            ("CONN_ID=mysql://host_1/", {"CONN_ID": "mysql://host_1"}),
             (
                 "CONN_ID1=mysql://host_1/\nCONN_ID2=mysql://host_2/",
-                {"CONN_ID1": ["mysql://host_1"], "CONN_ID2": 
["mysql://host_2"]},
+                {"CONN_ID1": "mysql://host_1", "CONN_ID2": "mysql://host_2"},
             ),
             (
                 "CONN_ID1=mysql://host_1/\n # AAAA\nCONN_ID2=mysql://host_2/",
-                {"CONN_ID1": ["mysql://host_1"], "CONN_ID2": 
["mysql://host_2"]},
+                {"CONN_ID1": "mysql://host_1", "CONN_ID2": "mysql://host_2"},
             ),
             (
                 
"\n\n\n\nCONN_ID1=mysql://host_1/\n\n\n\n\nCONN_ID2=mysql://host_2/\n\n\n",
-                {"CONN_ID1": ["mysql://host_1"], "CONN_ID2": 
["mysql://host_2"]},
+                {"CONN_ID1": "mysql://host_1", "CONN_ID2": "mysql://host_2"},
             ),
         )
     )
     def test_env_file_should_load_connection(self, file_content, 
expected_connection_uris):
         with mock_local_file(file_content):
-            connections_by_conn_id = local_filesystem.load_connections("a.env")
+            connection_by_conn_id = 
local_filesystem.load_connections_dict("a.env")
             connection_uris_by_conn_id = {
-                conn_id: [connection.get_uri() for connection in connections]
-                for conn_id, connections in connections_by_conn_id.items()
+                conn_id: connection.get_uri()
+                for conn_id, connection in connection_by_conn_id.items()
             }
 
             self.assertEqual(expected_connection_uris, 
connection_uris_by_conn_id)
@@ -156,22 +156,22 @@ class TestLoadConnection(unittest.TestCase):
     def test_env_file_invalid_format(self, content, expected_message):
         with mock_local_file(content):
             with self.assertRaisesRegex(AirflowFileParseException, 
re.escape(expected_message)):
-                local_filesystem.load_connections("a.env")
+                local_filesystem.load_connections_dict("a.env")
 
     @parameterized.expand(
         (
-            ({"CONN_ID": "mysql://host_1"}, {"CONN_ID": ["mysql://host_1"]}),
-            ({"CONN_ID": ["mysql://host_1"]}, {"CONN_ID": ["mysql://host_1"]}),
-            ({"CONN_ID": {"uri": "mysql://host_1"}}, {"CONN_ID": 
["mysql://host_1"]}),
-            ({"CONN_ID": [{"uri": "mysql://host_1"}]}, {"CONN_ID": 
["mysql://host_1"]}),
+            ({"CONN_ID": "mysql://host_1"}, {"CONN_ID": "mysql://host_1"}),
+            ({"CONN_ID": ["mysql://host_1"]}, {"CONN_ID": "mysql://host_1"}),
+            ({"CONN_ID": {"uri": "mysql://host_1"}}, {"CONN_ID": 
"mysql://host_1"}),
+            ({"CONN_ID": [{"uri": "mysql://host_1"}]}, {"CONN_ID": 
"mysql://host_1"}),
         )
     )
     def test_json_file_should_load_connection(self, file_content, 
expected_connection_uris):
         with mock_local_file(json.dumps(file_content)):
-            connections_by_conn_id = 
local_filesystem.load_connections("a.json")
+            connections_by_conn_id = 
local_filesystem.load_connections_dict("a.json")
             connection_uris_by_conn_id = {
-                conn_id: [connection.get_uri() for connection in connections]
-                for conn_id, connections in connections_by_conn_id.items()
+                conn_id: connection.get_uri()
+                for conn_id, connection in connections_by_conn_id.items()
             }
 
             self.assertEqual(expected_connection_uris, 
connection_uris_by_conn_id)
@@ -181,15 +181,16 @@ class TestLoadConnection(unittest.TestCase):
             ({"CONN_ID": None}, "Unexpected value type: <class 'NoneType'>."),
             ({"CONN_ID": 1}, "Unexpected value type: <class 'int'>."),
             ({"CONN_ID": [2]}, "Unexpected value type: <class 'int'>."),
-            ({"CONN_ID": ["mysql://host_1", None]}, "Unexpected value type: 
<class 'NoneType'>."),
+            ({"CONN_ID": [None]}, "Unexpected value type: <class 
'NoneType'>."),
             ({"CONN_ID": {"AAA": "mysql://host_1"}}, "The object have illegal 
keys: AAA."),
             ({"CONN_ID": {"conn_id": "BBBB"}}, "Mismatch conn_id."),
+            ({"CONN_ID": ["mysql://", "mysql://"]}, "Found multiple values for 
CONN_ID in a.json."),
         )
     )
     def test_env_file_invalid_input(self, file_content, 
expected_connection_uris):
         with mock_local_file(json.dumps(file_content)):
             with self.assertRaisesRegex(AirflowException, 
re.escape(expected_connection_uris)):
-                local_filesystem.load_connections("a.json")
+                local_filesystem.load_connections_dict("a.json")
 
     @mock.patch("airflow.secrets.local_filesystem.os.path.exists", 
return_value=False)
     def test_missing_file(self, mock_exists):
@@ -197,11 +198,11 @@ class TestLoadConnection(unittest.TestCase):
             AirflowException,
             re.escape("File a.json was not found. Check the configuration of 
your Secrets backend."),
         ):
-            local_filesystem.load_connections("a.json")
+            local_filesystem.load_connections_dict("a.json")
 
     @parameterized.expand(
         (
-            ("""CONN_A: 'mysql://host_a'""", {"CONN_A": ["mysql://host_a"]}),
+            ("""CONN_A: 'mysql://host_a'""", {"CONN_A": "mysql://host_a"}),
             ("""
             conn_a: mysql://hosta
             conn_b:
@@ -215,66 +216,80 @@ class TestLoadConnection(unittest.TestCase):
                  extra__google_cloud_platform__keyfile_dict:
                    a: b
                  extra__google_cloud_platform__keyfile_path: asaa""",
-                {"conn_a": ["mysql://hosta"],
-                    "conn_b": 
[''.join("""scheme://Login:None@host:1234/lschema?
+                {"conn_a": "mysql://hosta",
+                    "conn_b": ''.join("""scheme://Login:None@host:1234/lschema?
                         
extra__google_cloud_platform__keyfile_dict=%7B%27a%27%3A+%27b%27%7D
-                        
&extra__google_cloud_platform__keyfile_path=asaa""".split())]}),
+                        
&extra__google_cloud_platform__keyfile_path=asaa""".split())}),
         )
     )
     def test_yaml_file_should_load_connection(self, file_content, 
expected_connection_uris):
         with mock_local_file(file_content):
-            connections_by_conn_id = 
local_filesystem.load_connections("a.yaml")
+            connections_by_conn_id = 
local_filesystem.load_connections_dict("a.yaml")
             connection_uris_by_conn_id = {
-                conn_id: [connection.get_uri() for connection in connections]
-                for conn_id, connections in connections_by_conn_id.items()
+                conn_id: connection.get_uri()
+                for conn_id, connection in connections_by_conn_id.items()
             }
 
             self.assertEqual(expected_connection_uris, 
connection_uris_by_conn_id)
 
     @parameterized.expand(
         (
-            ("""conn_c:
-               conn_type: scheme
-               host: host
-               schema: lschema
-               login: Login
-               password: None
-               port: 1234
-               extra_dejson:
-                 aws_conn_id: bbb
-                 region_name: ccc
-                 """, {"conn_c": [{"aws_conn_id": "bbb", "region_name": 
"ccc"}]}),
-            ("""conn_d:
-               conn_type: scheme
-               host: host
-               schema: lschema
-               login: Login
-               password: None
-               port: 1234
-               extra_dejson:
-                 extra__google_cloud_platform__keyfile_dict:
-                   a: b
-                 extra__google_cloud_platform__key_path: xxx
-                 """, {"conn_d": 
[{"extra__google_cloud_platform__keyfile_dict": {"a": "b"},
-                                   "extra__google_cloud_platform__key_path": 
"xxx"}]}),
-            ("""conn_d:
-               conn_type: scheme
-               host: host
-               schema: lschema
-               login: Login
-               password: None
-               port: 1234
-               extra: '{\"extra__google_cloud_platform__keyfile_dict\": 
{\"a\": \"b\"}}'""", {"conn_d": [
-                {"extra__google_cloud_platform__keyfile_dict": {"a": "b"}}]})
-
+            (
+                """
+                conn_c:
+                   conn_type: scheme
+                   host: host
+                   schema: lschema
+                   login: Login
+                   password: None
+                   port: 1234
+                   extra_dejson:
+                     aws_conn_id: bbb
+                     region_name: ccc
+                 """,
+                {"conn_c": {"aws_conn_id": "bbb", "region_name": "ccc"}},
+            ),
+            (
+                """
+                conn_d:
+                   conn_type: scheme
+                   host: host
+                   schema: lschema
+                   login: Login
+                   password: None
+                   port: 1234
+                   extra_dejson:
+                     extra__google_cloud_platform__keyfile_dict:
+                       a: b
+                     extra__google_cloud_platform__key_path: xxx
+                """,
+                {
+                    "conn_d": {
+                        "extra__google_cloud_platform__keyfile_dict": {"a": 
"b"},
+                        "extra__google_cloud_platform__key_path": "xxx",
+                    }
+                },
+            ),
+            (
+                """
+                conn_d:
+                   conn_type: scheme
+                   host: host
+                   schema: lschema
+                   login: Login
+                   password: None
+                   port: 1234
+                   extra: '{\"extra__google_cloud_platform__keyfile_dict\": 
{\"a\": \"b\"}}'
+                """,
+                {"conn_d": {"extra__google_cloud_platform__keyfile_dict": 
{"a": "b"}}},
+            ),
         )
     )
     def test_yaml_file_should_load_connection_extras(self, file_content, 
expected_extras):
         with mock_local_file(file_content):
-            connections_by_conn_id = 
local_filesystem.load_connections("a.yaml")
+            connections_by_conn_id = 
local_filesystem.load_connections_dict("a.yaml")
             connection_uris_by_conn_id = {
-                conn_id: [connection.extra_dejson for connection in 
connections]
-                for conn_id, connections in connections_by_conn_id.items()
+                conn_id: connection.extra_dejson for conn_id, connection in 
connections_by_conn_id.items()
             }
             self.assertEqual(expected_extras, connection_uris_by_conn_id)
 
@@ -298,7 +313,7 @@ class TestLoadConnection(unittest.TestCase):
     def test_yaml_invalid_extra(self, file_content, expected_message):
         with mock_local_file(file_content):
             with self.assertRaisesRegex(AirflowException, 
re.escape(expected_message)):
-                local_filesystem.load_connections("a.yaml")
+                local_filesystem.load_connections_dict("a.yaml")
 
     @parameterized.expand(
         (
@@ -308,7 +323,7 @@ class TestLoadConnection(unittest.TestCase):
     def test_ensure_unique_connection_env(self, file_content):
         with mock_local_file(file_content):
             with self.assertRaises(ConnectionNotUnique):
-                local_filesystem.load_connections("a.env")
+                local_filesystem.load_connections_dict("a.env")
 
     @parameterized.expand(
         (
@@ -323,7 +338,7 @@ class TestLoadConnection(unittest.TestCase):
     def test_ensure_unique_connection_json(self, file_content):
         with mock_local_file(json.dumps(file_content)):
             with self.assertRaises(ConnectionNotUnique):
-                local_filesystem.load_connections("a.json")
+                local_filesystem.load_connections_dict("a.json")
 
     @parameterized.expand(
         (
@@ -336,7 +351,7 @@ class TestLoadConnection(unittest.TestCase):
     def test_ensure_unique_connection_yaml(self, file_content):
         with mock_local_file(file_content):
             with self.assertRaises(ConnectionNotUnique):
-                local_filesystem.load_connections("a.yaml")
+                local_filesystem.load_connections_dict("a.yaml")
 
 
 class TestLocalFileBackend(unittest.TestCase):

Reply via email to