This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-9-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit c9cc72608d66b1481ef1029499ce54e30b178b12
Author: Jed Cunningham <[email protected]>
AuthorDate: Thu Apr 25 15:20:40 2024 -0400

    Fix trigger kwarg encryption migration (#39246)
    
    Do the encryption in the migration itself, and fix support for offline
    migrations as well.
    
    The offline up migration won't actually encrypt the trigger kwargs, as there
    isn't a safe way to accomplish that, so the decryption process checks
    whether the kwargs are encrypted and short-circuits if they aren't.
    
    The offline down migration will now print out a warning that the offline
    migration will fail if any trigger rows exist. I think this is
    the best we can do for that scenario (and folks willing to do offline
    migrations will hopefully be able to understand the situation).
    
    This also solves the "encrypting the already encrypted kwargs" bug in
    2.9.0.
    
    (cherry picked from commit adeb7f7cba2ab2b16be2e006c17e140fe91fdf77)
---
 .../0140_2_9_0_update_trigger_kwargs_type.py       | 46 +++++++++++++++++++---
 airflow/models/trigger.py                          | 10 ++++-
 airflow/utils/db.py                                | 39 ------------------
 docs/apache-airflow/img/airflow_erd.sha256         |  2 +-
 docs/apache-airflow/img/airflow_erd.svg            |  4 +-
 docs/apache-airflow/migrations-ref.rst             |  2 +-
 tests/models/test_trigger.py                       | 17 ++++++++
 7 files changed, 70 insertions(+), 50 deletions(-)

diff --git a/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py b/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py
index dbde1201e4..2d57686e43 100644
--- a/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py
+++ b/airflow/migrations/versions/0140_2_9_0_update_trigger_kwargs_type.py
@@ -16,18 +16,22 @@
 # specific language governing permissions and limitations
 # under the License.

-"""update trigger kwargs type
+"""update trigger kwargs type and encrypt

 Revision ID: 1949afb29106
 Revises: ee1467d4aa35
 Create Date: 2024-03-17 22:09:09.406395

 """
+import json
+from textwrap import dedent
+
+from alembic import context, op
 import sqlalchemy as sa
+from sqlalchemy.orm import lazyload

+from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.models.trigger import Trigger
-from alembic import op
-
 from airflow.utils.sqlalchemy import ExtendedJSON

 # revision identifiers, used by Alembic.
@@ -38,13 +42,43 @@ depends_on = None
 airflow_version = "2.9.0"


+def get_session() -> sa.orm.Session:
+    conn = op.get_bind()
+    sessionmaker = sa.orm.sessionmaker()
+    return sessionmaker(bind=conn)
+
 def upgrade():
-    """Update trigger kwargs type to string"""
+    """Update trigger kwargs type to string and encrypt"""
     with op.batch_alter_table("trigger") as batch_op:
         batch_op.alter_column("kwargs", type_=sa.Text(), )

+    if not context.is_offline_mode():
+        session = get_session()
+        try:
+            for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
+                trigger.kwargs = trigger.kwargs
+            session.commit()
+        finally:
+            session.close()
+

 def downgrade():
-    """Unapply update trigger kwargs type to string"""
+    """Unapply update trigger kwargs type to string and encrypt"""
+    if context.is_offline_mode():
+        print(dedent("""
+        ------------
+        --  WARNING: Unable to decrypt trigger kwargs automatically in offline mode!
+        --  If any trigger rows exist when you do an offline downgrade, the migration will fail.
+        ------------
+        """))
+    else:
+        session = get_session()
+        try:
+            for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
+                trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
+            session.commit()
+        finally:
+            session.close()
+
     with op.batch_alter_table("trigger") as batch_op:
-        batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using="kwargs::json")
+        batch_op.alter_column("kwargs", type_=ExtendedJSON(), postgresql_using='kwargs::json')
diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index 32fec3b343..cde76ca241 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -116,7 +116,15 @@ class Trigger(Base):
         from airflow.models.crypto import get_fernet
         from airflow.serialization.serialized_objects import BaseSerialization

-        decrypted_kwargs = json.loads(get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8"))
+        # We weren't able to encrypt the kwargs in all migration paths,
+        # so we need to handle the case where they are not encrypted.
+        # Triggers aren't long lasting, so we can skip encrypting them now.
+        if encrypted_kwargs.startswith("{"):
+            decrypted_kwargs = json.loads(encrypted_kwargs)
+        else:
+            decrypted_kwargs = json.loads(
+                get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8")
+            )

         return BaseSerialization.deserialize(decrypted_kwargs)

diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index c0d282a587..b7997498bb 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -972,33 +972,6 @@ def synchronize_log_template(*, session: Session = NEW_SESSION) -> None:
         session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id))


-def encrypt_trigger_kwargs(*, session: Session) -> None:
-    """Encrypt trigger kwargs."""
-    from airflow.models.trigger import Trigger
-    from airflow.serialization.serialized_objects import BaseSerialization
-
-    for trigger in session.query(Trigger):
-        # convert serialized dict to string and encrypt it
-        trigger.kwargs = BaseSerialization.deserialize(json.loads(trigger.encrypted_kwargs))
-    session.commit()
-
-
-def decrypt_trigger_kwargs(*, session: Session) -> None:
-    """Decrypt trigger kwargs."""
-    from airflow.models.trigger import Trigger
-    from airflow.serialization.serialized_objects import BaseSerialization
-
-    if not inspect(session.bind).has_table(Trigger.__tablename__):
-        # table does not exist, nothing to do
-        # this can happen when we downgrade to an old version before the Trigger table was added
-        return
-
-    for trigger in session.scalars(select(Trigger.encrypted_kwargs)):
-        # decrypt the string and convert it to serialized dict
-        trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
-    session.commit()
-
-
 def check_conn_id_duplicates(session: Session) -> Iterable[str]:
     """
     Check unique conn_id in connection table.
@@ -1666,12 +1639,6 @@ def upgradedb(
         _reserialize_dags(session=session)
     add_default_pool_if_not_exists(session=session)
     synchronize_log_template(session=session)
-    if _revision_greater(
-        config,
-        _REVISION_HEADS_MAP["2.9.0"],
-        _get_current_revision(session=session),
-    ):
-        encrypt_trigger_kwargs(session=session)


 @provide_session
@@ -1744,12 +1711,6 @@ def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session:
         else:
             log.info("Applying downgrade migrations.")
             command.downgrade(config, revision=to_revision, sql=show_sql_only)
-            if _revision_greater(
-                config,
-                _REVISION_HEADS_MAP["2.9.0"],
-                to_revision,
-            ):
-                decrypt_trigger_kwargs(session=session)


 def drop_airflow_models(connection):
diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256
index 09f84daea2..6f623a5a64 100644
--- a/docs/apache-airflow/img/airflow_erd.sha256
+++ b/docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
-2a24225537326f38be5df14e0b7a8dca867122093e0fa932f1a11ac12d1fb11c
\ No newline at end of file
+3eb263f117248f914f64bf7cf44757526ecc00f222677629a11602f3bae7cdf0
\ No newline at end of file
diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg
index dc32fe0566..3dbb55dd4b 100644
--- a/docs/apache-airflow/img/airflow_erd.svg
+++ b/docs/apache-airflow/img/airflow_erd.svg
@@ -1375,14 +1375,14 @@
 <g id="edge43" class="edge">
 <title>task_instance&#45;&#45;xcom</title>
 <path fill="none" stroke="#7f7f7f" stroke-dasharray="5,2" d="M1198.1,-831.83C1228.72,-828.75 1260.55,-825.32 1290.36,-821.86"/>
-<text text-anchor="start" x="1280.36" y="-825.66" font-family="Times,serif" font-size="14.00">1</text>
+<text text-anchor="start" x="1259.36" y="-825.66" font-family="Times,serif" font-size="14.00">0..N</text>
 <text text-anchor="start" x="1198.1" y="-835.63" font-family="Times,serif" font-size="14.00">1</text>
 </g>
 <!-- task_instance&#45;&#45;xcom -->
 <g id="edge44" class="edge">
 <title>task_instance&#45;&#45;xcom</title>
 <path fill="none" stroke="#7f7f7f" stroke-dasharray="5,2" d="M1198.1,-845.13C1228.72,-842.41 1260.55,-838.85 1290.36,-834.82"/>
-<text text-anchor="start" x="1259.36" y="-838.62" font-family="Times,serif" font-size="14.00">0..N</text>
+<text text-anchor="start" x="1280.36" y="-838.62" font-family="Times,serif" font-size="14.00">1</text>
 <text text-anchor="start" x="1198.1" y="-848.93" font-family="Times,serif" font-size="14.00">1</text>
 </g>
 <!-- rendered_task_instance_fields -->
diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst
index 76ad56631f..e8faea6d0d 100644
--- a/docs/apache-airflow/migrations-ref.rst
+++ b/docs/apache-airflow/migrations-ref.rst
@@ -39,7 +39,7 @@ Here's the list of all the Database Migrations that are executed via when you ru
 +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
 | Revision ID                     | Revises ID        | Airflow Version   | Description                                                  |
 +=================================+===================+===================+==============================================================+
-| ``1949afb29106`` (head)         | ``ee1467d4aa35``  | ``2.9.0``         | update trigger kwargs type                                   |
+| ``1949afb29106`` (head)         | ``ee1467d4aa35``  | ``2.9.0``         | update trigger kwargs type and encrypt                       |
 +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
 | ``ee1467d4aa35``                | ``b4078ac230a1``  | ``2.9.0``         | add display name for dag and task instance                   |
 +---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py
index a3dd6ce35a..6be2086f34 100644
--- a/tests/models/test_trigger.py
+++ b/tests/models/test_trigger.py
@@ -17,6 +17,7 @@
 from __future__ import annotations

 import datetime
+import json
 from typing import Any, AsyncIterator

 import pytest
@@ -27,6 +28,7 @@ from airflow.jobs.job import Job
 from airflow.jobs.triggerer_job_runner import TriggererJobRunner
 from airflow.models import TaskInstance, Trigger
 from airflow.operators.empty import EmptyOperator
+from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from airflow.utils import timezone
 from airflow.utils.session import create_session
@@ -378,3 +380,18 @@ def test_serialize_sensitive_kwargs():
     assert isinstance(trigger_row.encrypted_kwargs, str)
     assert "value1" not in trigger_row.encrypted_kwargs
     assert "value2" not in trigger_row.encrypted_kwargs
+
+
+def test_kwargs_not_encrypted():
+    """
+    Tests that we don't decrypt kwargs if they aren't encrypted.
+    We weren't able to encrypt the kwargs in all migration paths.
+    """
+    trigger = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={})
+    # force the `encrypted_kwargs` to be unencrypted, like they would be after an offline upgrade
+    trigger.encrypted_kwargs = json.dumps(
+        BaseSerialization.serialize({"param1": "value1", "param2": "value2"})
+    )
+
+    assert trigger.kwargs["param1"] == "value1"
+    assert trigger.kwargs["param2"] == "value2"

Reply via email to