stale[bot] closed pull request #3805: [AIRFLOW-2062] Add per-connection KMS encryption.
URL: https://github.com/apache/incubator-airflow/pull/3805
 
 
   

This pull request was opened from a forked repository.
Because GitHub hides the original diff of a foreign (forked) pull request once
it is merged or closed, the diff is reproduced below for the sake of provenance:

diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index 1c5494ead1..15b061c94c 100644
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -1133,7 +1133,8 @@ def version(args):  # noqa
 
 
 alternative_conn_specs = ['conn_type', 'conn_host',
-                          'conn_login', 'conn_password', 'conn_schema', 
'conn_port']
+                          'conn_login', 'conn_password', 'conn_schema', 
'conn_port',
+                          'kms_conn_id', 'kms_extra']
 
 
 @cli_utils.action_logging
@@ -1235,7 +1236,10 @@ def connections(args):
             return
 
         if args.conn_uri:
-            new_conn = Connection(conn_id=args.conn_id, uri=args.conn_uri)
+            new_conn = Connection(conn_id=args.conn_id,
+                                  uri=args.conn_uri,
+                                  kms_conn_id=args.kms_conn_id,
+                                  kms_extra=args.kms_extra)
         else:
             new_conn = Connection(conn_id=args.conn_id,
                                   conn_type=args.conn_type,
@@ -1243,7 +1247,10 @@ def connections(args):
                                   login=args.conn_login,
                                   password=args.conn_password,
                                   schema=args.conn_schema,
-                                  port=args.conn_port)
+                                  port=args.conn_port,
+                                  kms_conn_id=args.kms_conn_id,
+                                  kms_extra=args.kms_extra
+                                  )
         if args.conn_extra is not None:
             new_conn.set_extra(args.conn_extra)
 
@@ -1883,6 +1890,15 @@ class CLIFactory(object):
             ('--conn_extra',),
             help='Connection `Extra` field, optional when adding a connection',
             type=str),
+        'kms_conn_id': Arg(
+            ('--kms_conn_id',),
+            help='An existing connection to use when encrypting this 
connection with a '
+                 'KMS, optional when adding a connection',
+            type=str),
+        'kms_extra': Arg(
+            ('--kms_extra',),
+            help='Connection `KMS Extra` field, optional when adding a 
connection',
+            type=str),
         # users
         'username': Arg(
             ('--username',),
diff --git a/airflow/contrib/hooks/gcp_kms_hook.py 
b/airflow/contrib/hooks/gcp_kms_hook.py
index 6f2b3aedff..63e35fbe89 100644
--- a/airflow/contrib/hooks/gcp_kms_hook.py
+++ b/airflow/contrib/hooks/gcp_kms_hook.py
@@ -20,6 +20,7 @@
 
 import base64
 
+from airflow.hooks.kmsapi_hook import KmsApiHook
 from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
 
 from apiclient.discovery import build
@@ -35,7 +36,7 @@ def _b64decode(s):
     return base64.b64decode(s.encode('utf-8'))
 
 
-class GoogleCloudKMSHook(GoogleCloudBaseHook):
+class GoogleCloudKMSHook(GoogleCloudBaseHook, KmsApiHook):
     """
     Interact with Google Cloud KMS. This hook uses the Google Cloud Platform
     connection.
@@ -106,3 +107,17 @@ def decrypt(self, key_name, ciphertext, 
authenticated_data=None):
 
         plaintext = _b64decode(response['plaintext'])
         return plaintext
+
+    def encrypt_conn_key(self, connection):
+        kms_extras = connection.kms_extra_dejson
+        key_name = kms_extras['kms_extra__google_cloud_platform__key_name']
+        conn_key = connection._plain_conn_key
+
+        connection.conn_key = self.encrypt(key_name, conn_key)
+
+    def decrypt_conn_key(self, connection):
+        kms_extras = connection.kms_extra_dejson
+        key_name = kms_extras['kms_extra__google_cloud_platform__key_name']
+        conn_key = connection.conn_key
+
+        connection._plain_conn_key = self.decrypt(key_name, conn_key)
diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py
index 103fa6260b..fe663f61c1 100644
--- a/airflow/hooks/base_hook.py
+++ b/airflow/hooks/base_hook.py
@@ -22,16 +22,9 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
-import os
-import random
-
 from airflow.models import Connection
-from airflow.exceptions import AirflowException
-from airflow.utils.db import provide_session
 from airflow.utils.log.logging_mixin import LoggingMixin
 
-CONN_ENV_PREFIX = 'AIRFLOW_CONN_'
-
 
 class BaseHook(LoggingMixin):
     """
@@ -44,48 +37,9 @@ class BaseHook(LoggingMixin):
     def __init__(self, source):
         pass
 
-    @classmethod
-    @provide_session
-    def _get_connections_from_db(cls, conn_id, session=None):
-        db = (
-            session.query(Connection)
-            .filter(Connection.conn_id == conn_id)
-            .all()
-        )
-        session.expunge_all()
-        if not db:
-            raise AirflowException(
-                "The conn_id `{0}` isn't defined".format(conn_id))
-        return db
-
-    @classmethod
-    def _get_connection_from_env(cls, conn_id):
-        environment_uri = os.environ.get(CONN_ENV_PREFIX + conn_id.upper())
-        conn = None
-        if environment_uri:
-            conn = Connection(conn_id=conn_id, uri=environment_uri)
-        return conn
-
-    @classmethod
-    def get_connections(cls, conn_id):
-        conn = cls._get_connection_from_env(conn_id)
-        if conn:
-            conns = [conn]
-        else:
-            conns = cls._get_connections_from_db(conn_id)
-        return conns
-
-    @classmethod
-    def get_connection(cls, conn_id):
-        conn = random.choice(cls.get_connections(conn_id))
-        if conn.host:
-            log = LoggingMixin().log
-            log.info("Using connection to: %s", conn.host)
-        return conn
-
     @classmethod
     def get_hook(cls, conn_id):
-        connection = cls.get_connection(conn_id)
+        connection = Connection.get_connection(conn_id)
         return connection.get_hook()
 
     def get_conn(self):
@@ -99,3 +53,11 @@ def get_pandas_df(self, sql):
 
     def run(self, sql):
         raise NotImplementedError()
+
+    @classmethod
+    def get_connections(cls, conn_id):
+        return Connection.get_connections(conn_id)
+
+    @classmethod
+    def get_connection(cls, conn_id):
+        return Connection.get_connection(conn_id)
diff --git a/airflow/hooks/kmsapi_hook.py b/airflow/hooks/kmsapi_hook.py
new file mode 100644
index 0000000000..cbae89d471
--- /dev/null
+++ b/airflow/hooks/kmsapi_hook.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import
+from __future__ import unicode_literals
+
+from airflow.hooks.base_hook import BaseHook
+
+
+class KmsApiHook(BaseHook):
+    """
+    Abstract base class for KMS hooks. KMS hooks should support encryption
+    and decryption services. In addition, a KMS hook should support encrypting
+    and decrypting connection keys.
+    """
+
+    def encrypt_conn_key(self, connection):
+        """
+        Accepts a Connection object and sets `connection.conn_key` by
+        encrypting `connection._plain_conn_key` via the KMS.
+        """
+        raise NotImplementedError()
+
+    def decrypt_conn_key(self, connection):
+        """
+        Accepts a Connection object and sets `connection._plain_conn_key` by
+        decrypting `connection.conn_key` via the KMS.
+        """
+        raise NotImplementedError()
diff --git 
a/airflow/migrations/versions/d7a2586b258a_add_kms_fields_to_connection.py 
b/airflow/migrations/versions/d7a2586b258a_add_kms_fields_to_connection.py
new file mode 100644
index 0000000000..d2fb2be1af
--- /dev/null
+++ b/airflow/migrations/versions/d7a2586b258a_add_kms_fields_to_connection.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""add KMS fields to connection
+
+Revision ID: d7a2586b258a
+Revises: 9635ae0956e7
+Create Date: 2018-07-19 09:13:46.044044
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = 'd7a2586b258a'
+down_revision = '9635ae0956e7'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.add_column('connection', sa.Column('conn_key', sa.String(length=200),
+                                          nullable=True))
+    op.add_column('connection', sa.Column('_kms_conn_id', 
sa.String(length=250),
+                                          nullable=True))
+    op.add_column('connection', sa.Column('_kms_extra', sa.String(length=5000),
+                                          nullable=True))
+
+
+def downgrade():
+    op.drop_column('connection', 'conn_key')
+    op.drop_column('connection', '_kms_conn_id')
+    op.drop_column('connection', '_kms_extra')
diff --git a/airflow/models.py b/airflow/models.py
index e4a50bc476..4c87833d18 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -55,6 +55,7 @@
 import traceback
 import warnings
 import hashlib
+import random
 
 import uuid
 from datetime import datetime
@@ -118,7 +119,7 @@ class InvalidFernetToken(Exception):
     pass
 
 
-class NullFernet(object):
+class NullFernet(LoggingMixin):
     """
     A "Null" encryptor class that doesn't encrypt or decrypt but that presents
     a similar interface to Fernet.
@@ -129,6 +130,14 @@ class NullFernet(object):
     """
     is_encrypted = False
 
+    def __init__(self, k):
+        self.log.warn(
+            "cryptography not found - values will not be stored encrypted.", 
exc_info=1)
+
+    @classmethod
+    def generate_key():
+        return ""
+
     def decrpyt(self, b):
         return b
 
@@ -136,6 +145,13 @@ def encrypt(self, b):
         return b
 
 
+try:
+    from cryptography.fernet import Fernet, InvalidToken
+    InvalidFernetToken = InvalidToken  # noqa: F811
+except BuiltinImportError:
+    Fernet = NullFernet
+
+
 _fernet = None
 
 
@@ -143,8 +159,7 @@ def get_fernet():
     """
     Deferred load of Fernet key.
 
-    This function could fail either because Cryptography is not installed
-    or because the Fernet key is invalid.
+    This function could fail because the Fernet key is invalid.
 
     :return: Fernet object
     :raises: AirflowException if there's a problem trying to load Fernet
@@ -152,21 +167,10 @@ def get_fernet():
     global _fernet
     if _fernet:
         return _fernet
-    try:
-        from cryptography.fernet import Fernet, InvalidToken
-        global InvalidFernetToken
-        InvalidFernetToken = InvalidToken
-
-    except BuiltinImportError:
-        LoggingMixin().log.warn("cryptography not found - values will not be 
stored "
-                                "encrypted.",
-                                exc_info=1)
-        _fernet = NullFernet()
-        return _fernet
 
     try:
         _fernet = Fernet(configuration.conf.get('core', 
'FERNET_KEY').encode('utf-8'))
-        _fernet.is_encrypted = True
+        _fernet.is_encrypted = (Fernet != NullFernet)
         return _fernet
     except (ValueError, TypeError) as ve:
         raise AirflowException("Could not create Fernet object: {}".format(ve))
@@ -631,6 +635,8 @@ class Connection(Base, LoggingMixin):
     """
     __tablename__ = "connection"
 
+    _conn_env_prefix = 'AIRFLOW_CONN_'
+
     id = Column(Integer(), primary_key=True)
     conn_id = Column(String(ID_LEN))
     conn_type = Column(String(500))
@@ -642,6 +648,11 @@ class Connection(Base, LoggingMixin):
     is_encrypted = Column(Boolean, unique=False, default=False)
     is_extra_encrypted = Column(Boolean, unique=False, default=False)
     _extra = Column('extra', String(5000))
+    conn_key = Column(String(200))
+    _kms_conn_id = Column(String(ID_LEN))
+    _kms_extra = Column(String(5000))
+
+    _plain_conn_key = None
 
     _types = [
         ('docker', 'Docker Registry',),
@@ -683,8 +694,15 @@ def __init__(
             self, conn_id=None, conn_type=None,
             host=None, login=None, password=None,
             schema=None, port=None, extra=None,
+            conn_key=None, kms_conn_id=None, kms_extra=None,
             uri=None):
         self.conn_id = conn_id
+
+        # KMS first, so that later properties are correctly encrypted.
+        self.conn_key = conn_key
+        self._kms_conn_id = kms_conn_id
+        self._kms_extra = kms_extra
+
         if uri:
             self.parse_from_uri(uri)
         else:
@@ -696,6 +714,19 @@ def __init__(
             self.port = port
             self.extra = extra
 
+    def _init_keys(self):
+        kms_conn = Connection.get_connection(self.kms_conn_id)
+        if kms_conn.kms_conn_id:
+            raise AirflowException("Can't chain KMS encrypted Connections.")
+
+        kms_hook = kms_conn.get_kms_hook()
+
+        if self.conn_key:
+            kms_hook.decrypt_conn_key(self)
+        else:
+            self._plain_conn_key = Fernet.generate_key()
+            kms_hook.encrypt_conn_key(self)
+
     def parse_from_uri(self, uri):
         temp_uri = urlparse(uri)
         hostname = temp_uri.hostname or ''
@@ -715,7 +746,7 @@ def parse_from_uri(self, uri):
 
     def get_password(self):
         if self._password and self.is_encrypted:
-            fernet = get_fernet()
+            fernet = self._get_fernet()
             if not fernet.is_encrypted:
                 raise AirflowException(
                     "Can't decrypt encrypted password for login={}, \
@@ -726,7 +757,7 @@ def get_password(self):
 
     def set_password(self, value):
         if value:
-            fernet = get_fernet()
+            fernet = self._get_fernet()
             self._password = fernet.encrypt(bytes(value, 'utf-8')).decode()
             self.is_encrypted = fernet.is_encrypted
 
@@ -737,7 +768,7 @@ def password(cls):
 
     def get_extra(self):
         if self._extra and self.is_extra_encrypted:
-            fernet = get_fernet()
+            fernet = self._get_fernet()
             if not fernet.is_encrypted:
                 raise AirflowException(
                     "Can't decrypt `extra` params for login={},\
@@ -748,7 +779,7 @@ def get_extra(self):
 
     def set_extra(self, value):
         if value:
-            fernet = get_fernet()
+            fernet = self._get_fernet()
             self._extra = fernet.encrypt(bytes(value, 'utf-8')).decode()
             self.is_extra_encrypted = fernet.is_encrypted
         else:
@@ -760,6 +791,44 @@ def extra(cls):
         return synonym('_extra',
                        descriptor=property(cls.get_extra, cls.set_extra))
 
+    @property
+    def kms_conn_id(self):
+        return self._kms_conn_id
+
+    @property
+    def kms_extra(self):
+        return self._kms_extra
+
+    def update_kms(self, kms_conn_id=None, kms_extra=None, clear=False):
+        """
+        Updates a KMS-encrypted connection to use a new set of KMS credentials.
+        This prevents broken access to encrypted fields when updating either or
+        both KMS fields.
+
+        :param kms_conn_id: The new KMS connection ID to use.
+        :type kms_conn_id: str
+        :param kms_extra: The new KMS extra field string to use.
+        :type kms_extra: str
+        :param clear: If True, allows `None` values in the other fields to 
clear those
+                      fields. Default is to preserve existing fields.
+        :type clear: bool
+        """
+        # Decrypt the current encrypted fields.
+        password = self.password
+        extra = self.extra
+        # Throw out the old key.
+        self.conn_key = None
+        self._plain_conn_key = None
+        if clear or kms_conn_id:
+            self._kms_conn_id = kms_conn_id
+        if clear or kms_extra:
+            self._kms_extra = kms_extra
+        if self.kms_conn_id:
+            self._init_keys()
+        # Re-encrypt the sensitive fields.
+        self.password = password
+        self.extra = extra
+
     def get_hook(self):
         try:
             if self.conn_type == 'mysql':
@@ -819,6 +888,19 @@ def get_hook(self):
         except Exception:
             pass
 
+    def get_kms_hook(self):
+        """
+        Returns a KMS Hook that uses this connection, if this connection's 
type has
+        an associated KMS Hook.
+
+        :rtype: KmsApiHook
+        """
+        if self.conn_type == 'google_cloud_platform':
+            from airflow.contrib.hooks.gcp_kms_hook import GoogleCloudKMSHook
+            return GoogleCloudKMSHook(gcp_conn_id=self.conn_id)
+        else:
+            raise ValueError("No KMS hook for conn_type 
{0}".format(self.conn_type))
+
     def __repr__(self):
         return self.conn_id
 
@@ -835,6 +917,74 @@ def extra_dejson(self):
 
         return obj
 
+    @property
+    def kms_extra_dejson(self):
+        """Returns the kms_extra property by deserializing json."""
+        obj = {}
+        if self.kms_extra:
+            try:
+                obj = json.loads(self.kms_extra)
+            except Exception as e:
+                self.log.exception(e)
+                self.log.error("Failed parsing the json for conn_id %s", 
self.conn_id)
+
+        return obj
+
+    @classmethod
+    @provide_session
+    def _get_connections_from_db(cls, conn_id, session=None):
+        db = (
+            session.query(Connection)
+            .filter(Connection.conn_id == conn_id)
+            .all()
+        )
+        session.expunge_all()
+        if not db:
+            raise AirflowException(
+                "The conn_id `{0}` isn't defined".format(conn_id))
+        return db
+
+    @classmethod
+    def _get_connection_from_env(cls, conn_id):
+        environment_uri = os.environ.get(cls._conn_env_prefix + 
conn_id.upper())
+        conn = None
+        if environment_uri:
+            conn = Connection(conn_id=conn_id, uri=environment_uri)
+        return conn
+
+    @classmethod
+    def get_connections(cls, conn_id):
+        conn = cls._get_connection_from_env(conn_id)
+        if conn:
+            conns = [conn]
+        else:
+            conns = cls._get_connections_from_db(conn_id)
+        return conns
+
+    @classmethod
+    def get_connection(cls, conn_id):
+        conn = random.choice(cls.get_connections(conn_id))
+        if conn.host:
+            log = LoggingMixin().log
+            log.info("Using connection to: %s", conn.host)
+        return conn
+
+    def _get_fernet(self):
+        global_fernet = get_fernet()
+
+        if self.kms_conn_id and not self._plain_conn_key:
+            self._init_keys()
+
+        if self._plain_conn_key:
+            try:
+                conn_fernet = Fernet(self._plain_conn_key)
+                conn_fernet.is_encrypted = True
+                return conn_fernet
+            except (ValueError, TypeError) as ve:
+                raise AirflowException("Could not create Fernet object: 
{}".format(ve))
+        else:
+            return global_fernet
+
 
 class DagPickle(Base):
     """
diff --git a/docs/howto/manage-connections.rst 
b/docs/howto/manage-connections.rst
index f869a08b3c..c7b7a5f23d 100644
--- a/docs/howto/manage-connections.rst
+++ b/docs/howto/manage-connections.rst
@@ -28,7 +28,9 @@ to create a new connection.
 2. Choose the connection type with the ``Conn Type`` field.
 3. Fill in the remaining fields. See
    :ref:`manage-connections-connection-types` for a description of the fields
-   belonging to the different connection types.
+   belonging to the different connection types, and 
+   :ref:`secure-connections-kms` if you want to encrypt connection
+   credentials using a Key Management System.
 4. Click the ``Save`` button to create the connection.
 
 Editing a Connection with the UI
diff --git a/docs/howto/secure-connections.rst 
b/docs/howto/secure-connections.rst
index bb13b1bb08..21e6b3d127 100644
--- a/docs/howto/secure-connections.rst
+++ b/docs/howto/secure-connections.rst
@@ -30,3 +30,16 @@ variable over the value in ``airflow.cfg``:
 
 4. Restart Airflow webserver.
 5. For existing connections (the ones that you had defined before installing ``airflow[crypto]`` and creating a Fernet key), you need to open each connection in the connection admin UI, re-type the password, and save it.
+
+.. _secure-connections-kms:
+
+Encrypting Connections with a KMS
+---------------------------------
+
+In addition to using a Fernet key, Airflow supports encrypting connections using keys
+managed by a Key Management System (KMS). To encrypt a connection using a KMS, follow these steps:
+
+1. Set up or create an account on a KMS of your choice. Create at least one encryption/decryption key.
+2. Install the crypto package: ``pip install apache-airflow[crypto]``
+3. Create a regular (non-KMS-encrypted) connection defining how to connect to your KMS.
+4. Create (or modify) the connection you wish to encrypt. Choose the connection from step 3 for the "KMS Conn Id" field, and fill out any additional fields that appear.
\ No newline at end of file
diff --git a/tests/contrib/hooks/test_gcp_kms_hook.py 
b/tests/contrib/hooks/test_gcp_kms_hook.py
index eabf20e564..b5fa0bc465 100644
--- a/tests/contrib/hooks/test_gcp_kms_hook.py
+++ b/tests/contrib/hooks/test_gcp_kms_hook.py
@@ -158,3 +158,27 @@ def test_decrypt_authdata(self, mock_service):
                                           body=body)
         execute_method.assert_called_with()
         self.assertEqual(plaintext, ret_val)
+
+    @mock.patch(KMS_STRING.format('GoogleCloudKMSHook.encrypt'))
+    def test_encrypt_conn(self, mock_encrypt):
+        conn_key = "Test Key"
+        mock_conn = mock.Mock()
+        mock_conn._plain_conn_key = conn_key
+        mock_conn.kms_extra_dejson = {
+            "kms_extra__google_cloud_platform__key_name": TEST_KEY_ID}
+
+        self.kms_hook.encrypt_conn_key(mock_conn)
+
+        mock_encrypt.assert_called_with(TEST_KEY_ID, conn_key)
+
+    @mock.patch(KMS_STRING.format('GoogleCloudKMSHook.decrypt'))
+    def test_decrypt_conn(self, mock_decrypt):
+        conn_key = "Test Key"
+        mock_conn = mock.Mock()
+        mock_conn.conn_key = conn_key
+        mock_conn.kms_extra_dejson = {
+            "kms_extra__google_cloud_platform__key_name": TEST_KEY_ID}
+
+        self.kms_hook.decrypt_conn_key(mock_conn)
+
+        mock_decrypt.assert_called_with(TEST_KEY_ID, conn_key)
diff --git a/tests/models.py b/tests/models.py
index 999b1be1bb..87f2c8a50c 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -62,6 +62,32 @@
     os.path.dirname(os.path.realpath(__file__)), 'dags')
 
 
+class FernetTest(unittest.TestCase):
+    def setUp(self):
+        configuration.load_test_config()
+        models._fernet = None
+
+    def tearDown(self):
+        models._fernet = None
+        configuration.load_test_config()
+
+    def test_get_fernet(self):
+        test_fernet = models.get_fernet()
+        self.assertTrue(test_fernet.is_encrypted)
+        self.assertIs(test_fernet, models.get_fernet())
+
+    def test_get_fernet_bad_key(self):
+        too_short_key = ""
+        configuration.conf.set('core', 'FERNET_KEY', too_short_key)
+        with self.assertRaises(AirflowException):
+            models.get_fernet()
+
+    @patch('airflow.models.Fernet', new=models.NullFernet)
+    def test_get_fernet_no_fernet(self):
+        test_fernet = models.get_fernet()
+        self.assertFalse(test_fernet.is_encrypted)
+
+
 class DagTest(unittest.TestCase):
 
     def test_params_not_passed_is_empty_dict(self):
@@ -2696,3 +2722,94 @@ def test_connection_from_uri_with_extras(self):
         self.assertEqual(connection.port, 1234)
         self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value',
                                                        'extra2': '/path/'})
+
+    @patch.object(Connection, 'get_connection')
+    def test_connection_kms_encryption(self, mock_get_conn):
+        from mock import Mock
+        mock_conn = Mock()
+        mock_hook = Mock()
+
+        mock_conn.kms_conn_id = None  # KMS conn is not itself encrypted
+        mock_get_conn.return_value = mock_conn
+        mock_conn.get_kms_hook.return_value = mock_hook
+
+        password = 'test_pass'
+        extra = 'test_extra'
+        kms_conn_id = 'test_kms_id'
+        connection = Connection(password=password, extra=extra, 
kms_conn_id=kms_conn_id)
+
+        self.assertEqual(connection.password, password)
+        self.assertEqual(connection.extra, extra)
+        mock_get_conn.assert_called_with(kms_conn_id)
+        mock_conn.get_kms_hook.assert_called_with()
+        mock_hook.encrypt_conn_key.assert_called_with(connection)
+
+    @patch.object(Connection, 'get_connection')
+    def test_connection_from_uri_kms_encryption(self, mock_get_conn):
+        from mock import Mock
+        mock_conn = Mock()
+        mock_hook = Mock()
+
+        mock_conn.kms_conn_id = None  # KMS conn is not itself encrypted
+        mock_get_conn.return_value = mock_conn
+        mock_conn.get_kms_hook.return_value = mock_hook
+
+        kms_conn_id = 'test_kms_id'
+        uri = 'scheme://user:password@host%2flocation:1234/schema?'\
+            'extra1=a%20value&extra2=%2fpath%2f'
+        connection = Connection(uri=uri, kms_conn_id=kms_conn_id)
+
+        self.assertEqual(connection.conn_type, 'scheme')
+        self.assertEqual(connection.host, 'host/location')
+        self.assertEqual(connection.schema, 'schema')
+        self.assertEqual(connection.login, 'user')
+        self.assertEqual(connection.password, 'password')
+        self.assertEqual(connection.port, 1234)
+        self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value',
+                                                       'extra2': '/path/'})
+        mock_get_conn.assert_called_with(kms_conn_id)
+        mock_conn.get_kms_hook.assert_called_with()
+        mock_hook.encrypt_conn_key.assert_called_with(connection)
+
+    def test_get_kms_hook_missing(self):
+        test_conn_type = "Missing KMS Type"
+        test_connection = Connection(conn_type=test_conn_type)
+        with self.assertRaises(ValueError):
+            test_connection.get_kms_hook()
+
+    @patch.object(Connection, 'get_connection')
+    def test_update_kms(self, mock_get_conn):
+        from mock import Mock
+        mock_conn = Mock()
+        mock_hook = Mock()
+
+        mock_conn.kms_conn_id = None  # KMS conn is not itself encrypted
+        mock_get_conn.return_value = mock_conn
+        mock_conn.get_kms_hook.return_value = mock_hook
+
+        password = 'test_pass'
+        extra = 'test_extra'
+        kms_conn_id = 'test_kms_id'
+        kms_extra = 'test_kms_extra'
+        connection = Connection(password=password, extra=extra, 
kms_conn_id=kms_conn_id,
+                                kms_extra=kms_extra)
+
+        new_kms_conn_id = 'test_kms_id_2'
+        new_kms_extra = 'test_kms_extra_2'
+        first_conn_key = connection._plain_conn_key
+        connection.update_kms(kms_conn_id=new_kms_conn_id, 
kms_extra=new_kms_extra)
+
+        self.assertEqual(connection.password, password)
+        self.assertEqual(connection.extra, extra)
+        self.assertEqual(connection.kms_conn_id, new_kms_conn_id)
+        self.assertEqual(connection.kms_extra, new_kms_extra)
+        self.assertNotEqual(connection._plain_conn_key, first_conn_key)
+        mock_get_conn.assert_called_with(new_kms_conn_id)
+        mock_conn.get_kms_hook.assert_called_with()
+        mock_hook.encrypt_conn_key.assert_called_with(connection)
+
+        connection.update_kms(clear=True)
+        self.assertEqual(connection.password, password)
+        self.assertEqual(connection.extra, extra)
+        self.assertIsNone(connection.kms_conn_id)
+        self.assertIsNone(connection.kms_extra)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to