This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a7120d28383 Move fernet utils in serde to remove core dependency (#60771)
a7120d28383 is described below

commit a7120d28383ec9a3dee0c78b0237982a9c4ee49e
Author: Amogh Desai <[email protected]>
AuthorDate: Tue Jan 20 19:20:06 2026 +0530

    Move fernet utils in serde to remove core dependency (#60771)
---
 task-sdk/src/airflow/sdk/crypto.py                 | 127 +++++++++++++++++++
 .../src/airflow/sdk/serde/serializers/deltalake.py |   4 +-
 .../src/airflow/sdk/serde/serializers/iceberg.py   |   4 +-
 task-sdk/tests/task_sdk/docs/test_public_api.py    |   1 +
 task-sdk/tests/task_sdk/serde/test_serializers.py  |   4 +
 task-sdk/tests/task_sdk/test_crypto.py             | 137 +++++++++++++++++++++
 6 files changed, 273 insertions(+), 4 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/crypto.py b/task-sdk/src/airflow/sdk/crypto.py
new file mode 100644
index 00000000000..a4ac3c8b4f8
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/crypto.py
@@ -0,0 +1,127 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from functools import cache
+from typing import TYPE_CHECKING, Protocol
+
+log = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from cryptography.fernet import MultiFernet
+
+
+class FernetProtocol(Protocol):
+    """
+    Structural type for the objects returned by ``get_fernet()``.
+
+    It is used only for static type checking (IDEs, mypy, etc.).
+
+    ``rotate()`` is intentionally not part of this Protocol. ``_RealFernet``
+    provides it, but ``_NullFernet`` -- what ``get_fernet()`` returns when
+    FERNET_KEY is empty -- does not, so only call ``rotate()`` where encryption
+    is known to be enabled (e.g. CLI key-rotation commands).
+    """
+
+    # True for the real (encrypting) implementation, False for the pass-through one.
+    is_encrypted: bool
+
+    def decrypt(self, msg: bytes | str, ttl: int | None = None) -> bytes:
+        """Decrypt *msg* and return the plaintext bytes; ``ttl`` is optional."""
+
+    def encrypt(self, msg: bytes) -> bytes:
+        """Encrypt *msg* and return the resulting token bytes."""
+
+
+class _NullFernet:
+    """
+    Pass-through stand-in for Fernet, used when no FERNET_KEY is configured.
+
+    It exposes the same ``encrypt``/``decrypt`` interface as the real wrapper but
+    never transforms data, so callers need no special case for "encryption off".
+    """
+
+    is_encrypted = False
+
+    def decrypt(self, msg: bytes | str, ttl: int | None = None) -> bytes:
+        """Return *msg* as bytes, unchanged; ``ttl`` is accepted for compatibility and ignored."""
+        if isinstance(msg, str):
+            return msg.encode("utf-8")
+        if not isinstance(msg, bytes):
+            raise ValueError(f"Expected bytes or str, got {type(msg)}")
+        return msg
+
+    def encrypt(self, msg: bytes) -> bytes:
+        """Return *msg* unchanged -- no encryption is performed."""
+        return msg
+
+
+class _RealFernet:
+    """
+    Thin adapter over a ``MultiFernet`` that reports ``is_encrypted = True``.
+
+    Wrapping the cryptography object lets ``get_fernet()`` keep a single return
+    type while still telling callers whether encryption is actually active.
+    """
+
+    is_encrypted = True
+
+    def __init__(self, fernet: MultiFernet):
+        # Every method below delegates straight to this MultiFernet instance.
+        self._fernet = fernet
+
+    def encrypt(self, msg: bytes) -> bytes:
+        """Encrypt *msg* with the primary key of the wrapped MultiFernet."""
+        inner = self._fernet
+        return inner.encrypt(msg)
+
+    def decrypt(self, msg: bytes | str, ttl: int | None = None) -> bytes:
+        """Decrypt *msg* via the wrapped MultiFernet, forwarding ``ttl`` unchanged."""
+        inner = self._fernet
+        return inner.decrypt(msg, ttl)
+
+    def rotate(self, msg: bytes | str) -> bytes:
+        """Re-encrypt *msg* under the primary key (delegates to ``MultiFernet.rotate``)."""
+        inner = self._fernet
+        return inner.rotate(msg)
+
+
+@cache
+def get_fernet() -> FernetProtocol:
+    """
+    Lazily build the Fernet helper from the SDK's ``[core] FERNET_KEY`` setting.
+
+    The result is memoized with ``functools.cache``: the key is read and parsed
+    (and the "empty key" warning logged) at most once per process. Call
+    ``get_fernet.cache_clear()`` after changing the configuration.
+
+    ``FERNET_KEY`` may contain several comma-separated keys; they are combined
+    into a ``MultiFernet`` in the given order.
+
+    :return: a ``_RealFernet`` when a key is configured, otherwise a
+        pass-through ``_NullFernet`` whose ``is_encrypted`` is ``False``.
+    :raises ImportError: if the ``cryptography`` package is not installed.
+    :raises airflow.sdk.exceptions.AirflowException: if the configured key is
+        invalid (the underlying ``ValueError``/``TypeError`` is chained as the cause).
+    """
+    # Imported lazily so merely importing this module does not require cryptography.
+    from cryptography.fernet import Fernet, MultiFernet
+
+    from airflow.sdk.configuration import conf
+    from airflow.sdk.exceptions import AirflowException
+
+    try:
+        fernet_key = conf.get("core", "FERNET_KEY")
+        if not fernet_key:
+            log.warning("empty cryptography key - values will not be stored encrypted.")
+            return _NullFernet()
+
+        fernet = MultiFernet([Fernet(fernet_part.encode("utf-8")) for fernet_part in fernet_key.split(",")])
+        return _RealFernet(fernet)
+    except (ValueError, TypeError) as value_error:
+        # Chain the original error so the root cause stays visible in the traceback.
+        raise AirflowException(f"Could not create Fernet object: {value_error}") from value_error
diff --git a/task-sdk/src/airflow/sdk/serde/serializers/deltalake.py b/task-sdk/src/airflow/sdk/serde/serializers/deltalake.py
index 7999144e0e3..672ae33ae06 100644
--- a/task-sdk/src/airflow/sdk/serde/serializers/deltalake.py
+++ b/task-sdk/src/airflow/sdk/serde/serializers/deltalake.py
@@ -37,7 +37,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
     if not isinstance(o, DeltaTable):
         return "", "", 0, False
 
-    from airflow.models.crypto import get_fernet
+    from airflow.sdk.crypto import get_fernet
 
     # we encrypt the information here until we have as part of the
     # storage options can have sensitive information
@@ -58,7 +58,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
 def deserialize(cls: type, version: int, data: dict):
     from deltalake.table import DeltaTable
 
-    from airflow.models.crypto import get_fernet
+    from airflow.sdk.crypto import get_fernet
 
     if version > __version__:
         raise TypeError("serialized version is newer than class version")
diff --git a/task-sdk/src/airflow/sdk/serde/serializers/iceberg.py b/task-sdk/src/airflow/sdk/serde/serializers/iceberg.py
index ee2427836ff..b242ee6d4b5 100644
--- a/task-sdk/src/airflow/sdk/serde/serializers/iceberg.py
+++ b/task-sdk/src/airflow/sdk/serde/serializers/iceberg.py
@@ -37,7 +37,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
     if not isinstance(o, Table):
         return "", "", 0, False
 
-    from airflow.models.crypto import get_fernet
+    from airflow.sdk.crypto import get_fernet
 
     # we encrypt the catalog information here until we have
     # global catalog management in airflow and the properties
@@ -59,7 +59,7 @@ def deserialize(cls: type, version: int, data: dict):
     from pyiceberg.catalog import load_catalog
     from pyiceberg.table import Table
 
-    from airflow.models.crypto import get_fernet
+    from airflow.sdk.crypto import get_fernet
 
     if version > __version__:
         raise TypeError("serialized version is newer than class version")
diff --git a/task-sdk/tests/task_sdk/docs/test_public_api.py b/task-sdk/tests/task_sdk/docs/test_public_api.py
index 2186cd32b12..f53887ec5c9 100644
--- a/task-sdk/tests/task_sdk/docs/test_public_api.py
+++ b/task-sdk/tests/task_sdk/docs/test_public_api.py
@@ -61,6 +61,7 @@ def test_airflow_sdk_no_unexpected_exports():
         "observability",
         "plugins_manager",
         "listener",
+        "crypto",
         "providers_manager_runtime",
     }
     unexpected = actual - public - ignore
diff --git a/task-sdk/tests/task_sdk/serde/test_serializers.py b/task-sdk/tests/task_sdk/serde/test_serializers.py
index f8a57caa50f..a090e8d6b13 100644
--- a/task-sdk/tests/task_sdk/serde/test_serializers.py
+++ b/task-sdk/tests/task_sdk/serde/test_serializers.py
@@ -29,6 +29,7 @@ import pandas as pd
 import pendulum
 import pendulum.tz
 import pytest
+from cryptography.fernet import Fernet
 from dateutil.tz import tzutc
 from kubernetes.client import models as k8s
 from packaging import version
@@ -42,6 +43,7 @@ from airflow.sdk.definitions.param import Param, ParamsDict
 from airflow.sdk.serde import CLASSNAME, DATA, VERSION, decode, deserialize, 
serialize
 from airflow.sdk.serde.serializers import builtin
 
+from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.markers import 
skip_if_force_lowest_dependencies_marker
 
 PENDULUM3 = version.parse(metadata.version("pendulum")).major == 3
@@ -284,6 +286,7 @@ class TestSerializers:
         with pytest.raises(TypeError, match=msg):
             deserialize(klass, version, data)
 
+    @conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
     def test_iceberg(self):
         pytest.importorskip("pyiceberg", minversion="2.0.0")
         from pyiceberg.catalog import Catalog
@@ -310,6 +313,7 @@ class TestSerializers:
         mock_load_catalog.assert_called_with("catalog", uri=uri)
         mock_load_table.assert_called_with((identifier[1], identifier[2]))
 
+    @conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
     def test_deltalake(self):
         deltalake = pytest.importorskip("deltalake")
 
diff --git a/task-sdk/tests/task_sdk/test_crypto.py b/task-sdk/tests/task_sdk/test_crypto.py
new file mode 100644
index 00000000000..b1acbb5e718
--- /dev/null
+++ b/task-sdk/tests/task_sdk/test_crypto.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+from cryptography.fernet import Fernet
+
+from airflow.sdk.crypto import _NullFernet, _RealFernet, get_fernet
+from airflow.sdk.exceptions import AirflowException
+
+from tests_common.test_utils.config import conf_vars
+
+
+class TestNullFernet:
+    """Behaviour of the pass-through ``_NullFernet``."""
+
+    def test_decryption_of_invalid_type(self):
+        """Should raise ValueError for non-string/bytes input."""
+        with pytest.raises(ValueError, match="Expected bytes or str, got <class 'int'>"):
+            _NullFernet().decrypt(123)  # type: ignore[arg-type]
+
+    def test_encrypt_decrypt_roundtrip(self):
+        """Should preserve data through encrypt/decrypt cycle."""
+        fernet = _NullFernet()
+        payload = b"test message"
+
+        token = fernet.encrypt(payload)
+
+        # A null "encryption" must be the identity in both directions.
+        assert token == payload
+        assert fernet.decrypt(token) == payload
+
+
+class TestRealFernet:
+    """Behaviour of ``_RealFernet`` wrapping a real ``MultiFernet``."""
+
+    def test_encryption(self):
+        """Encryption should produce encrypted output different from input."""
+        from cryptography.fernet import MultiFernet
+
+        wrapper = _RealFernet(MultiFernet([Fernet(Fernet.generate_key())]))
+        plaintext = b"secret message"
+
+        token = wrapper.encrypt(plaintext)
+
+        assert token != plaintext
+        assert wrapper.decrypt(token) == plaintext
+
+    def test_rotate_reencrypt_with_primary_key(self):
+        """rotate() should re-encrypt data with the primary key."""
+        from cryptography.fernet import MultiFernet
+
+        old_key = Fernet.generate_key()
+        new_key = Fernet.generate_key()
+
+        # Token produced with the old key alone.
+        old_token = Fernet(old_key).encrypt(b"rotate test")
+
+        # new_key comes first (primary); old_key is only a decryption fallback.
+        wrapper = _RealFernet(MultiFernet([Fernet(new_key), Fernet(old_key)]))
+        rotated = wrapper.rotate(old_token)
+
+        assert rotated != old_token
+        assert Fernet(new_key).decrypt(rotated) == b"rotate test"
+
+
+class TestGetFernet:
+    """Tests for ``get_fernet()``, whose result is memoized with ``functools.cache``."""
+
+    @pytest.fixture(autouse=True)
+    def _reset_fernet_cache(self):
+        # get_fernet() caches its result process-wide. Clear it before each test so the
+        # conf_vars override is actually read, and again afterwards so a Fernet built
+        # from a temporary test key does not leak into unrelated tests.
+        get_fernet.cache_clear()
+        yield
+        get_fernet.cache_clear()
+
+    @conf_vars({("core", "FERNET_KEY"): ""})
+    def test_empty_key(self):
+        """An empty FERNET_KEY should fall back to a pass-through (unencrypted) Fernet."""
+        fernet = get_fernet()
+
+        assert not fernet.is_encrypted
+        test_data = b"unencrypted"
+        assert fernet.encrypt(test_data) == test_data
+        assert fernet.decrypt(test_data) == test_data
+
+    @conf_vars({("core", "FERNET_KEY"): Fernet.generate_key().decode()})
+    def test_valid_key_encrypts_data(self):
+        """Valid FERNET_KEY should return working encryption."""
+        fernet = get_fernet()
+
+        assert fernet.is_encrypted
+        original = b"sensitive data"
+        encrypted = fernet.encrypt(original)
+        assert encrypted != original
+        assert fernet.decrypt(encrypted) == original
+
+    @conf_vars({("core", "FERNET_KEY"): "invalid-key"})
+    def test_invalid_key(self):
+        """Invalid FERNET_KEY should raise an AirflowException."""
+        with pytest.raises(AirflowException, match="Could not create Fernet object"):
+            get_fernet()
+
+    def test_multiple_keys(self):
+        """Multiple comma separated keys should support key rotation."""
+        key1 = Fernet.generate_key()
+        key2 = Fernet.generate_key()
+
+        # encrypt with key1
+        data_encrypted_with_key1 = Fernet(key1).encrypt(b"secret data")
+
+        # get_fernet with both keys (key2 primary, key1 fallback)
+        with conf_vars({("core", "FERNET_KEY"): f"{key2.decode()},{key1.decode()}"}):
+            fernet = get_fernet()
+
+            # decrypt data encrypted with old key1
+            assert fernet.decrypt(data_encrypted_with_key1) == b"secret data"
+            # new data must be encrypted with the primary key2
+            new_encrypted = fernet.encrypt(b"new")
+            assert Fernet(key2).decrypt(new_encrypted) == b"new"
+
+    @conf_vars({("core", "FERNET_KEY"): Fernet.generate_key().decode()})
+    def test_caching(self):
+        """get_fernet() should return cached instance."""
+        fernet1 = get_fernet()
+        fernet2 = get_fernet()
+
+        assert fernet1 is fernet2

Reply via email to