pierrejeambrun commented on code in PR #48239:
URL: https://github.com/apache/airflow/pull/48239#discussion_r2016309674


##########
airflow-core/tests/unit/jobs/test_triggerer_job.py:
##########
@@ -589,3 +595,125 @@ def test_failed_trigger(session, dag_maker, 
supervisor_builder):
     assert task_instance.next_method == "__fail__"
     assert task_instance.next_kwargs["error"] == "Trigger failure"
     assert task_instance.next_kwargs["traceback"][-1] == "ModuleNotFoundError: 
No module named 'fake'\n"
+
+
+class CustomTrigger(BaseTrigger):
+    """Custom Trigger that will access one Variable and one Connection."""
+
+    def __init__(self, dag_id, run_id, task_id, map_index):
+        self.dag_id = dag_id
+        self.run_id = run_id
+        self.task_id = task_id
+        self.map_index = map_index
+
+    async def run(self, **args) -> AsyncIterator[TriggerEvent]:
+        import attrs
+
+        from airflow.sdk import Variable
+        from airflow.sdk.execution_time.xcom import XCom
+
+        conn = await sync_to_async(BaseHook.get_connection)("test_connection")
+        self.log.info("Loaded conn %s", conn.conn_id)
+
+        variable = await sync_to_async(Variable.get)("test_variable")
+        self.log.info("Loaded variable %s", variable)
+
+        xcom = await sync_to_async(XCom.get_one)(
+            key="test_xcom",
+            dag_id=self.dag_id,
+            run_id=self.run_id,
+            task_id=self.task_id,
+            map_index=self.map_index,
+        )
+        self.log.info("Loaded XCom %s", xcom)
+
+        yield TriggerEvent({"connection": attrs.asdict(conn), "variable": 
variable, "xcom": xcom})
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            f"{type(self).__module__}.{type(self).__qualname__}",
+            {"dag_id": self.dag_id, "run_id": self.run_id, "task_id": 
self.task_id},
+        )
+
+
+class DummyTriggerRunnerSupervisor(TriggerRunnerSupervisor):
+    """
+    Make sure that the Supervisor stops after handling the events and do not 
keep running forever so the
+    test can continue.
+    """
+
+    def handle_events(self):
+        self.stop = bool(self.events)
+        super().handle_events()
+
+
[email protected]
+# @patch("airflow.providers.standard.triggers.temporal.CustomTrigger", 
return_value=CustomTrigger)
+async def test_trigger_can_access_variables_and_connections(session, 
dag_maker, supervisor_builder):
+    """
+    Checks that the trigger will successfully access Variables, Connections 
and XCom.
+
+    This is the Supervisor side of the error reported in 
TestTriggerRunner::test_invalid_trigger
+    """
+
+    # Create the test DAG and task
+    with dag_maker(dag_id="trigger_accessing_variable_and_connection", 
session=session):
+        EmptyOperator(task_id="dummy1")
+    dr = dag_maker.create_dagrun()
+    task_instance = dr.task_instances[0]
+    # Make a task instance based on that and tie it to the trigger
+    task_instance.state = TaskInstanceState.DEFERRED
+
+    # Create a Trigger
+    trigger = CustomTrigger(dag_id=dr.dag_id, run_id=dr.run_id, 
task_id=task_instance.task_id, map_index=-1)
+    trigger_orm = Trigger(
+        classpath=trigger.serialize()[0],
+        kwargs={"dag_id": dr.dag_id, "run_id": dr.run_id, "task_id": 
task_instance.task_id, "map_index": -1},
+    )
+    trigger_orm.id = 1
+    session.add(trigger_orm)
+    session.commit()
+    task_instance.trigger_id = trigger_orm.id
+
+    # Create the appropriate Connection, Variable and XCom
+    connection = Connection(conn_id="test_connection", conn_type="http")
+    variable = Variable(key="test_variable", val="some_variable_value")
+    XComModel.set(
+        key="test_xcom",
+        value="some_xcom_value",
+        task_id=task_instance.task_id,
+        dag_id=dr.dag_id,
+        run_id=dr.run_id,
+        map_index=-1,
+        session=session,
+    )
+    session.add(connection)
+    session.add(variable)
+
+    job = Job()
+    session.add(job)
+    session.commit()
+
+    supervisor = DummyTriggerRunnerSupervisor.start(job=job, capacity=1, 
logger=None)
+    supervisor.run()
+
+    task_instance.refresh_from_db()
+    assert task_instance.state == TaskInstanceState.SCHEDULED
+    assert task_instance.next_method != "__fail__"
+    assert task_instance.next_kwargs == {
+        "event": {
+            "connection": {
+                "conn_id": "test_connection",
+                "conn_type": "http",
+                "description": None,
+                "host": None,
+                "schema": None,
+                "login": None,
+                "password": None,
+                "port": None,
+                "extra": None,
+            },
+            "variable": "some_variable_value",
+            "xcom": '"some_xcom_value"',

Review Comment:
   Just to draw attention to this: is that expected? (The XCom value comes back with nested quotes: `'"some_xcom_value"'`.)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to