This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a7120d28383 Move fernet utils in serde to remove core dependency (#60771)
a7120d28383 is described below
commit a7120d28383ec9a3dee0c78b0237982a9c4ee49e
Author: Amogh Desai <[email protected]>
AuthorDate: Tue Jan 20 19:20:06 2026 +0530
Move fernet utils in serde to remove core dependency (#60771)
---
task-sdk/src/airflow/sdk/crypto.py | 127 +++++++++++++++++++
.../src/airflow/sdk/serde/serializers/deltalake.py | 4 +-
.../src/airflow/sdk/serde/serializers/iceberg.py | 4 +-
task-sdk/tests/task_sdk/docs/test_public_api.py | 1 +
task-sdk/tests/task_sdk/serde/test_serializers.py | 4 +
task-sdk/tests/task_sdk/test_crypto.py | 137 +++++++++++++++++++++
6 files changed, 273 insertions(+), 4 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/crypto.py b/task-sdk/src/airflow/sdk/crypto.py
new file mode 100644
index 00000000000..a4ac3c8b4f8
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/crypto.py
@@ -0,0 +1,127 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from functools import cache
+from typing import TYPE_CHECKING, Protocol
+
+log = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+ from cryptography.fernet import MultiFernet
+
+
+class FernetProtocol(Protocol):
+    """
+    Structural type for the objects handed out by ``get_fernet()``.
+
+    It exists only for static analysis (mypy, IDEs); nothing checks it at runtime.
+
+    Only ``encrypt``/``decrypt`` and ``is_encrypted`` belong to the contract.
+    ``_RealFernet`` additionally offers ``rotate()``, but ``_NullFernet`` (what you
+    get in unit tests without a FERNET_KEY) does not. ``rotate()`` must therefore
+    only be called from CLI commands where encryption is guaranteed to be enabled.
+    """
+
+    # True when values are really encrypted, False for the pass-through implementation.
+    is_encrypted: bool
+
+    def encrypt(self, msg: bytes) -> bytes:
+        """Encrypt ``msg`` and return the resulting bytes."""
+
+    def decrypt(self, msg: bytes | str, ttl: int | None = None) -> bytes:
+        """Decrypt ``msg`` and return the plaintext bytes; ``ttl`` is passed to the implementation."""
+
+
+class _NullFernet:
+    """
+    Pass-through stand-in for Fernet, used when no encryption key is configured.
+
+    It exposes the same encrypt/decrypt interface as the real wrapper but leaves the
+    data untouched, so the rest of the code never has to know the difference. It also
+    keeps the "empty key" message from being shown many times (e.g. during
+    ``airflow db migrate``).
+    """
+
+    # This implementation never encrypts anything.
+    is_encrypted = False
+
+    def encrypt(self, msg: bytes) -> bytes:
+        """Return ``msg`` unchanged; no encryption is performed."""
+        return msg
+
+    def decrypt(self, msg: bytes | str, ttl: int | None = None) -> bytes:
+        """
+        Return ``msg`` as bytes without decrypting it.
+
+        ``str`` input is UTF-8 encoded. ``ttl`` is accepted only for interface
+        compatibility and is ignored.
+
+        :raises ValueError: if ``msg`` is neither ``bytes`` nor ``str``.
+        """
+        if isinstance(msg, str):
+            return msg.encode("utf-8")
+        if not isinstance(msg, bytes):
+            raise ValueError(f"Expected bytes or str, got {type(msg)}")
+        return msg
+
+
+class _RealFernet:
+    """
+    Thin adapter around a ``MultiFernet`` that reports ``is_encrypted = True``.
+
+    It lets ``get_fernet`` return either this or a pass-through object behind one
+    interface. Every call is delegated unchanged to the wrapped ``MultiFernet``.
+    """
+
+    is_encrypted = True
+
+    def __init__(self, fernet: MultiFernet):
+        # The wrapped MultiFernet that performs all the actual cryptography.
+        self._fernet = fernet
+
+    def encrypt(self, msg: bytes) -> bytes:
+        """Encrypt ``msg`` with the wrapped MultiFernet."""
+        wrapped = self._fernet
+        return wrapped.encrypt(msg)
+
+    def decrypt(self, msg: bytes | str, ttl: int | None = None) -> bytes:
+        """Decrypt ``msg`` with the wrapped MultiFernet, forwarding ``ttl`` as-is."""
+        wrapped = self._fernet
+        return wrapped.decrypt(msg, ttl)
+
+    def rotate(self, msg: bytes | str) -> bytes:
+        """
+        Re-encrypt ``msg`` under the primary key of the wrapped MultiFernet.
+
+        Not part of ``FernetProtocol``; only call this where encryption is known to be enabled.
+        """
+        wrapped = self._fernet
+        return wrapped.rotate(msg)
+
+
+@cache
+def get_fernet() -> FernetProtocol:
+    """
+    Deferred load of Fernet key from SDK configuration.
+
+    The key is read from ``[core] FERNET_KEY``. An empty value logs a warning and
+    yields a pass-through ``_NullFernet``. Otherwise the value is split on commas and
+    each part becomes a ``Fernet`` inside a ``MultiFernet``. The first key is used for
+    encryption, and the remaining keys are still accepted for decryption.
+
+    The result is memoised with ``functools.cache``, so the warning is emitted at most
+    once until ``get_fernet.cache_clear()`` is called. Call it after changing the
+    configuration, as the tests do.
+
+    This function could fail either because Cryptography is not installed
+    or because the Fernet key is invalid.
+
+    :return: Fernet object
+    :raises ImportError: if the ``cryptography`` package is not installed.
+    :raises airflow.sdk.exceptions.AirflowException: if there's a problem trying to
+        load Fernet; the underlying ``ValueError``/``TypeError`` is chained as ``__cause__``.
+    """
+    from cryptography.fernet import Fernet, MultiFernet
+
+    from airflow.sdk.configuration import conf
+    from airflow.sdk.exceptions import AirflowException
+
+    try:
+        fernet_key = conf.get("core", "FERNET_KEY")
+        if not fernet_key:
+            log.warning("empty cryptography key - values will not be stored encrypted.")
+            return _NullFernet()
+
+        fernet = MultiFernet([Fernet(fernet_part.encode("utf-8")) for fernet_part in fernet_key.split(",")])
+        return _RealFernet(fernet)
+    except (ValueError, TypeError) as value_error:
+        # Chain explicitly so the root cause (e.g. a malformed key) stays visible in tracebacks.
+        raise AirflowException(f"Could not create Fernet object: {value_error}") from value_error
diff --git a/task-sdk/src/airflow/sdk/serde/serializers/deltalake.py b/task-sdk/src/airflow/sdk/serde/serializers/deltalake.py
index 7999144e0e3..672ae33ae06 100644
--- a/task-sdk/src/airflow/sdk/serde/serializers/deltalake.py
+++ b/task-sdk/src/airflow/sdk/serde/serializers/deltalake.py
@@ -37,7 +37,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
if not isinstance(o, DeltaTable):
return "", "", 0, False
- from airflow.models.crypto import get_fernet
+ from airflow.sdk.crypto import get_fernet
# we encrypt the information here until we have as part of the
# storage options can have sensitive information
@@ -58,7 +58,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
def deserialize(cls: type, version: int, data: dict):
from deltalake.table import DeltaTable
- from airflow.models.crypto import get_fernet
+ from airflow.sdk.crypto import get_fernet
if version > __version__:
raise TypeError("serialized version is newer than class version")
diff --git a/task-sdk/src/airflow/sdk/serde/serializers/iceberg.py b/task-sdk/src/airflow/sdk/serde/serializers/iceberg.py
index ee2427836ff..b242ee6d4b5 100644
--- a/task-sdk/src/airflow/sdk/serde/serializers/iceberg.py
+++ b/task-sdk/src/airflow/sdk/serde/serializers/iceberg.py
@@ -37,7 +37,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
if not isinstance(o, Table):
return "", "", 0, False
- from airflow.models.crypto import get_fernet
+ from airflow.sdk.crypto import get_fernet
# we encrypt the catalog information here until we have
# global catalog management in airflow and the properties
@@ -59,7 +59,7 @@ def deserialize(cls: type, version: int, data: dict):
from pyiceberg.catalog import load_catalog
from pyiceberg.table import Table
- from airflow.models.crypto import get_fernet
+ from airflow.sdk.crypto import get_fernet
if version > __version__:
raise TypeError("serialized version is newer than class version")
diff --git a/task-sdk/tests/task_sdk/docs/test_public_api.py b/task-sdk/tests/task_sdk/docs/test_public_api.py
index 2186cd32b12..f53887ec5c9 100644
--- a/task-sdk/tests/task_sdk/docs/test_public_api.py
+++ b/task-sdk/tests/task_sdk/docs/test_public_api.py
@@ -61,6 +61,7 @@ def test_airflow_sdk_no_unexpected_exports():
"observability",
"plugins_manager",
"listener",
+ "crypto",
"providers_manager_runtime",
}
unexpected = actual - public - ignore
diff --git a/task-sdk/tests/task_sdk/serde/test_serializers.py b/task-sdk/tests/task_sdk/serde/test_serializers.py
index f8a57caa50f..a090e8d6b13 100644
--- a/task-sdk/tests/task_sdk/serde/test_serializers.py
+++ b/task-sdk/tests/task_sdk/serde/test_serializers.py
@@ -29,6 +29,7 @@ import pandas as pd
import pendulum
import pendulum.tz
import pytest
+from cryptography.fernet import Fernet
from dateutil.tz import tzutc
from kubernetes.client import models as k8s
from packaging import version
@@ -42,6 +43,7 @@ from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.sdk.serde import CLASSNAME, DATA, VERSION, decode, deserialize,
serialize
from airflow.sdk.serde.serializers import builtin
+from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.markers import
skip_if_force_lowest_dependencies_marker
PENDULUM3 = version.parse(metadata.version("pendulum")).major == 3
@@ -284,6 +286,7 @@ class TestSerializers:
with pytest.raises(TypeError, match=msg):
deserialize(klass, version, data)
+ @conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
def test_iceberg(self):
pytest.importorskip("pyiceberg", minversion="2.0.0")
from pyiceberg.catalog import Catalog
@@ -310,6 +313,7 @@ class TestSerializers:
mock_load_catalog.assert_called_with("catalog", uri=uri)
mock_load_table.assert_called_with((identifier[1], identifier[2]))
+ @conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
def test_deltalake(self):
deltalake = pytest.importorskip("deltalake")
diff --git a/task-sdk/tests/task_sdk/test_crypto.py b/task-sdk/tests/task_sdk/test_crypto.py
new file mode 100644
index 00000000000..b1acbb5e718
--- /dev/null
+++ b/task-sdk/tests/task_sdk/test_crypto.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+from cryptography.fernet import Fernet
+
+from airflow.sdk.crypto import _NullFernet, _RealFernet, get_fernet
+from airflow.sdk.exceptions import AirflowException
+
+from tests_common.test_utils.config import conf_vars
+
+
+class TestNullFernet:
+    def test_decryption_of_invalid_type(self):
+        """Decrypting something that is neither str nor bytes raises ValueError."""
+        with pytest.raises(ValueError, match="Expected bytes or str, got <class 'int'>"):
+            _NullFernet().decrypt(123)  # type: ignore[arg-type]
+
+    def test_encrypt_decrypt_roundtrip(self):
+        """The null implementation hands data back untouched in both directions."""
+        payload = b"test message"
+        fernet = _NullFernet()
+
+        ciphertext = fernet.encrypt(payload)
+
+        assert ciphertext == payload
+        assert fernet.decrypt(ciphertext) == payload
+
+
+class TestRealFernet:
+    def test_encryption(self):
+        """Encrypted output differs from the plaintext and decrypts back to it."""
+        from cryptography.fernet import MultiFernet
+
+        real_fernet = _RealFernet(MultiFernet([Fernet(Fernet.generate_key())]))
+        plaintext = b"secret message"
+
+        token = real_fernet.encrypt(plaintext)
+
+        assert token != plaintext
+        assert real_fernet.decrypt(token) == plaintext
+
+    def test_rotate_reencrypt_with_primary_key(self):
+        """rotate() re-encrypts a token produced by an older key under the primary key."""
+        from cryptography.fernet import MultiFernet
+
+        old_key = Fernet.generate_key()
+        new_key = Fernet.generate_key()
+        legacy_token = Fernet(old_key).encrypt(b"rotate test")
+
+        # new_key is listed first, so it is the primary; old_key only serves as a fallback.
+        real_fernet = _RealFernet(MultiFernet([Fernet(new_key), Fernet(old_key)]))
+        rotated = real_fernet.rotate(legacy_token)
+
+        assert Fernet(new_key).decrypt(rotated) == b"rotate test"
+        assert rotated != legacy_token
+
+
+class TestGetFernet:
+    @pytest.fixture(autouse=True)
+    def _reset_fernet_cache(self):
+        """
+        Clear the ``get_fernet()`` cache before and after every test.
+
+        ``get_fernet`` is memoised, so without the teardown clear a Fernet built from
+        one test's ``conf_vars`` override would stay cached after the override is
+        undone. It would then leak into later tests, including other modules that
+        call ``get_fernet()``.
+        """
+        get_fernet.cache_clear()
+        yield
+        get_fernet.cache_clear()
+
+    @conf_vars({("core", "FERNET_KEY"): ""})
+    def test_empty_key(self):
+        """An empty FERNET_KEY yields a pass-through (unencrypted) implementation."""
+        fernet = get_fernet()
+
+        assert not fernet.is_encrypted
+        test_data = b"unencrypted"
+        assert fernet.encrypt(test_data) == test_data
+        assert fernet.decrypt(test_data) == test_data
+
+    @conf_vars({("core", "FERNET_KEY"): Fernet.generate_key().decode()})
+    def test_valid_key_encrypts_data(self):
+        """Valid FERNET_KEY should return working encryption."""
+        fernet = get_fernet()
+
+        assert fernet.is_encrypted
+        original = b"sensitive data"
+        encrypted = fernet.encrypt(original)
+        assert encrypted != original
+        assert fernet.decrypt(encrypted) == original
+
+    @conf_vars({("core", "FERNET_KEY"): "invalid-key"})
+    def test_invalid_key(self):
+        """Invalid FERNET_KEY should raise an AirflowException."""
+        with pytest.raises(AirflowException, match="Could not create Fernet object"):
+            get_fernet()
+
+    def test_multiple_keys(self):
+        """Multiple comma separated keys should support key rotation."""
+        key1 = Fernet.generate_key()
+        key2 = Fernet.generate_key()
+
+        # encrypt with key1
+        data_encrypted_with_key1 = Fernet(key1).encrypt(b"secret data")
+
+        # get_fernet with both keys (key2 primary, key1 fallback); the autouse
+        # fixture guarantees nothing is cached before the override takes effect
+        with conf_vars({("core", "FERNET_KEY"): f"{key2.decode()},{key1.decode()}"}):
+            fernet = get_fernet()
+
+            # decrypt data encrypted with old key1
+            assert fernet.decrypt(data_encrypted_with_key1) == b"secret data"
+            new_encrypted = fernet.encrypt(b"new")
+            assert Fernet(key2).decrypt(new_encrypted) == b"new"
+
+    @conf_vars({("core", "FERNET_KEY"): Fernet.generate_key().decode()})
+    def test_caching(self):
+        """get_fernet() should return cached instance."""
+        fernet1 = get_fernet()
+        fernet2 = get_fernet()
+
+        assert fernet1 is fernet2