This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 458ac63f7cb Supporting Variable.set in execution time with task SDK (#49005)
458ac63f7cb is described below

commit 458ac63f7cb5b671483f0e521f1cb0ad57508675
Author: Amogh Desai <[email protected]>
AuthorDate: Thu Apr 10 12:56:00 2025 +0530

    Supporting Variable.set in execution time with task SDK (#49005)
    
    * Supporting Variable.set in execution time with task SDK
    
    * adding tests for variable set
    
    * fixing static checks and review comments from ash
---
 .../src/airflow/dag_processing/processor.py        |  5 ++-
 airflow-core/src/airflow/models/variable.py        | 24 +++++++++++
 .../tests/unit/dag_processing/test_processor.py    | 30 ++++++++++++++
 task-sdk/src/airflow/sdk/definitions/variable.py   | 10 +++++
 task-sdk/src/airflow/sdk/execution_time/context.py | 46 ++++++++++++++++++++++
 .../tests/task_sdk/definitions/test_variables.py   | 36 ++++++++++++++++-
 6 files changed, 149 insertions(+), 2 deletions(-)

diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index b6e07242dce..365d771a081 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -38,6 +38,7 @@ from airflow.sdk.execution_time.comms import (
     ErrorResponse,
     GetConnection,
     GetVariable,
+    PutVariable,
     VariableResult,
 )
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess
@@ -53,7 +54,7 @@ if TYPE_CHECKING:
     from airflow.typing_compat import Self

 ToManager = Annotated[
-    Union["DagFileParsingResult", GetConnection, GetVariable],
+    Union["DagFileParsingResult", GetConnection, GetVariable, PutVariable],
     Field(discriminator="type"),
 ]

@@ -290,6 +291,8 @@ class DagFileProcessorProcess(WatchedSubprocess):
                 dump_opts = {"exclude_unset": True}
             else:
                 resp = var
+        elif isinstance(msg, PutVariable):
+            self.client.variables.set(msg.key, msg.value, msg.description)
         else:
             log.error("Unhandled request", msg=msg)
             return
diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py
index e0b91686285..9c6af8ab727 100644
--- a/airflow-core/src/airflow/models/variable.py
+++ b/airflow-core/src/airflow/models/variable.py
@@ -201,6 +201,30 @@ class Variable(Base, LoggingMixin):
         """
         # check if the secret exists in the custom secrets' backend.
         Variable.check_for_write_conflict(key=key)
+
+        # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
+        # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
+        # back-compat layer
+
+        # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
+        # and should use the Task SDK API server path
+        if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
+            warnings.warn(
+                "Using Variable.set from `airflow.models` is deprecated. Please use `from airflow.sdk import "
+                "Variable` instead",
+                DeprecationWarning,
+                stacklevel=1,
+            )
+            from airflow.sdk import Variable as TaskSDKVariable
+
+            TaskSDKVariable.set(
+                key=key,
+                value=value,
+                description=description,
+                serialize_json=serialize_json,
+            )
+            return
+
         if serialize_json:
             stored_value = json.dumps(value, indent=2)
         else:
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py
index ca9670f81ae..4135c617f39 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -181,6 +181,36 @@ class TestDagFileProcessor:
         if result.import_errors:
             assert "VARIABLE_NOT_FOUND" in next(iter(result.import_errors.values()))

+    def test_top_level_variable_set(self, tmp_path: pathlib.Path):
+        from airflow.models.variable import Variable as VariableORM
+
+        logger_filehandle = MagicMock()
+
+        def dag_in_a_fn():
+            from airflow.sdk import DAG, Variable
+
+            Variable.set(key="mykey", value="myvalue")
+            with DAG(f"test_{Variable.get('mykey')}"):
+                ...
+
+        path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)
+        proc = DagFileProcessorProcess.start(
+            id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
+        )
+
+        while not proc.is_ready:
+            proc._service_subprocess(0.1)
+
+        with create_session() as session:
+            result = proc.parsing_result
+            assert result is not None
+            assert result.import_errors == {}
+            assert result.serialized_dags[0].dag_id == "test_myvalue"
+
+            all_vars = session.query(VariableORM).all()
+            assert len(all_vars) == 1
+            assert all_vars[0].key == "mykey"
+
     def test_top_level_connection_access(self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
         logger_filehandle = MagicMock()

diff --git a/task-sdk/src/airflow/sdk/definitions/variable.py b/task-sdk/src/airflow/sdk/definitions/variable.py
index 87b0ee29fab..9e80b8a5667 100644
--- a/task-sdk/src/airflow/sdk/definitions/variable.py
+++ b/task-sdk/src/airflow/sdk/definitions/variable.py
@@ -55,3 +55,13 @@ class Variable:
             if e.error.error == ErrorType.VARIABLE_NOT_FOUND and default is not NOTSET:
                 return default
             raise
+
+    @classmethod
+    def set(cls, key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None:
+        from airflow.sdk.exceptions import AirflowRuntimeError
+        from airflow.sdk.execution_time.context import _set_variable
+
+        try:
+            _set_variable(key, value, description, serialize_json=serialize_json)
+        except AirflowRuntimeError as e:
+            log.exception(e)
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index 2ff7dbbda08..fee91f3efaf 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -211,6 +211,52 @@ def _get_variable(key: str, deserialize_json: bool) -> Any:
     return variable.value


+def _set_variable(key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None:
+    # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
+    #   or `airflow.sdk.execution_time.variable`
+    #   A reason to not move it to `airflow.sdk.execution_time.comms` is that it
+    #   will make that module depend on Task SDK, which is not ideal because we intend to
+    #   keep Task SDK as a separate package than execution time mods.
+    import json
+
+    from airflow.sdk.execution_time.comms import PutVariable
+    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded
+    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+    # check for write conflicts on the worker
+    for secrets_backend in ensure_secrets_backend_loaded():
+        try:
+            var_val = secrets_backend.get_variable(key=key)
+            if var_val is not None:
+                _backend_name = type(secrets_backend).__name__
+                log.warning(
+                    "The variable %s is defined in the %s secrets backend, which takes "
+                    "precedence over reading from the database. The value in the database will be "
+                    "updated, but to read it you have to delete the conflicting variable "
+                    "from %s",
+                    key,
+                    _backend_name,
+                    _backend_name,
+                )
+        except Exception:
+            log.exception(
+                "Unable to retrieve variable from secrets backend (%s). Checking subsequent secrets backend.",
+                type(secrets_backend).__name__,
+            )
+
+    # Serialize before sending. Deliberately not wrapped in try/except: if json.dumps fails,
+    # falling through would send the raw, unserialized object to the API server as the
+    # variable's value, so let the error propagate to the caller (as core Variable.set does).
+    if serialize_json:
+        value = json.dumps(value, indent=2)
+
+    # It is best to have lock everywhere or nowhere on the SUPERVISOR_COMMS, lock was
+    # primarily added for triggers but it doesn't make sense to have it in some places
+    # and not in the rest. A lot of this will be simplified by https://github.com/apache/airflow/issues/46426
+    with SUPERVISOR_COMMS.lock:
+        SUPERVISOR_COMMS.send_request(log=log, msg=PutVariable(key=key, value=value, description=description))
+
+
 class ConnectionAccessor:
     """Wrapper to access Connection entries in template."""

diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py b/task-sdk/tests/task_sdk/definitions/test_variables.py
index 6560bdee903..242c5af407b 100644
--- a/task-sdk/tests/task_sdk/definitions/test_variables.py
+++ b/task-sdk/tests/task_sdk/definitions/test_variables.py
@@ -17,13 +17,14 @@
 # under the License.
 from __future__ import annotations

+import json
 from unittest import mock

 import pytest

 from airflow.configuration import initialize_secrets_backends
 from airflow.sdk import Variable
-from airflow.sdk.execution_time.comms import VariableResult
+from airflow.sdk.execution_time.comms import PutVariable, VariableResult
 from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS

 from tests_common.test_utils.config import conf_vars
@@ -55,6 +56,39 @@ class TestVariables:
         assert var is not None
         assert var == expected_value

+    @pytest.mark.parametrize(
+        "key, value, description, serialize_json",
+        [
+            pytest.param(
+                "key",
+                "value",
+                "description",
+                False,
+                id="simple-value",
+            ),
+            pytest.param(
+                "key2",
+                {"hi": "there", "hello": 42, "flag": True},
+                "description2",
+                True,
+                id="serialize-json-value",
+            ),
+        ],
+    )
+    def test_var_set(self, key, value, description, serialize_json, mock_supervisor_comms):
+        Variable.set(key=key, value=value, description=description, serialize_json=serialize_json)
+
+        expected_value = value
+        if serialize_json:
+            expected_value = json.dumps(value, indent=2)
+
+        mock_supervisor_comms.send_request.assert_called_once_with(
+            log=mock.ANY,
+            msg=PutVariable(
+                key=key, value=expected_value, description=description
+            ),
+        )
+

 class TestVariableFromSecrets:
     def test_var_get_from_secrets_found(self, mock_supervisor_comms, tmp_path):

Reply via email to