This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 824ec4a80d AIP-69: Add Executor to Edge Provider (#42048)
824ec4a80d is described below

commit 824ec4a80d41290894223dac3ed43eacf924a1d5
Author: Jens Scheffler <[email protected]>
AuthorDate: Thu Oct 17 23:38:12 2024 +0200

    AIP-69: Add Executor to Edge Provider (#42048)
    
    * Add Executor to Edge Provider
    
    * Review feedback + small adjustments from Niko
    
    * Add explicit note about non-performance in current state
    
    * Fix adoption of tasks when restarting scheduler
    
    (cherry picked from commit 4e9c2262f567a2511d02d4acd43b821fa0df45d2)
    
    * Optimize sync call for MVP executor
    
    * Adjust new folder structure from PR 42505
    
    * Review feedback: removed cleanup_stuck_queued_tasks() and added notes about performance tests
---
 .../edge_executor.rst                              | 262 +++++++++++++++++++++
 docs/apache-airflow-providers-edge/index.rst       |   8 +
 .../airflow/providers/edge/executors/__init__.py   |  22 ++
 .../providers/edge/executors/edge_executor.py      | 185 +++++++++++++++
 providers/src/airflow/providers/edge/provider.yaml |   3 +
 providers/tests/edge/executors/__init__.py         |  17 ++
 .../tests/edge/executors/test_edge_executor.py     | 152 ++++++++++++
 7 files changed, 649 insertions(+)

diff --git a/docs/apache-airflow-providers-edge/edge_executor.rst b/docs/apache-airflow-providers-edge/edge_executor.rst
new file mode 100644
index 0000000000..d27cb5bc69
--- /dev/null
+++ b/docs/apache-airflow-providers-edge/edge_executor.rst
@@ -0,0 +1,262 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+Edge Executor
+=============
+
+.. note::
+
+    The Edge Provider Package is an experimental preview. Features and stability are limited
+    and need to be improved over time. The target is to have full support in Airflow 3.
+    Once Airflow 3 support contains the Edge Provider, maintenance of the Airflow 2 package will
+    be discontinued.
+
+
+.. note::
+
+    As of Airflow 2.10.0, you can install the ``edge`` provider package to use this executor.
+    This can be done by installing ``apache-airflow-providers-edge`` or by installing Airflow
+    with the ``edge`` extra: ``pip install 'apache-airflow[edge]'``.
+
+    While the provider is in a not-ready state, a wheel release package must be manually built from the source tree
+    via ``breeze release-management prepare-provider-packages --include-not-ready-providers edge``
+    and then installed via pip from the generated wheel file.
+
+
+``EdgeExecutor`` is an option if you want to distribute tasks to workers running in different locations.
+You can also use it in parallel with other executors if needed. Change your ``airflow.cfg`` to point
+the executor parameter to ``EdgeExecutor`` and provide the related settings.
+
+The configuration parameters of the Edge Executor can be found in the Edge provider's :doc:`configurations-ref`.
+
+Here are a few imperative requirements for your workers:
+
+- ``airflow`` needs to be installed, and the CLI needs to be in the path
+- Airflow configuration settings should be homogeneous across the cluster
+- Operators that are executed on the Edge Worker need to have their dependencies
+  met in that context. Please take a look at the respective provider package
+  documentation
+- The worker needs to have access to its ``DAGS_FOLDER``, and you need to
+  synchronize the filesystems by your own means. A common setup would be to
+  store your ``DAGS_FOLDER`` in a Git repository and sync it across machines using
+  Chef, Puppet, Ansible, or whatever you use to configure machines in your
+  environment. If all your boxes have a common mount point, having your
+  pipeline files shared there should work as well
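
The first requirement above (the ``airflow`` CLI being on the path) can be verified with a tiny self-contained check; ``cli_available`` is an illustrative helper name, not part of Airflow:

```python
import shutil


def cli_available(name: str) -> bool:
    """Return True if the named executable can be found on the PATH."""
    return shutil.which(name) is not None


# On a correctly prepared worker host this should be True for "airflow".
if not cli_available("airflow"):
    print("airflow CLI not found on PATH - the worker will not start")
```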
+
+
+The minimum configuration for the Edge Worker to get it running is:
+
+- Section ``[core]``
+
+  - ``executor``: The executor must be set or added to be ``airflow.providers.edge.executors.EdgeExecutor``
+  - ``internal_api_secret_key``: An encryption key must be set on the webserver and the Edge Worker component as
+    a shared secret to authenticate traffic. It should be a random string like the fernet key
+    (but preferably not the same).
+
+- Section ``[edge]``
+
+  - ``api_enabled``: Must be set to true. It is disabled by default on purpose so that the
+    endpoint is not exposed unintentionally. This is the endpoint the worker connects to.
+    In a future release a dedicated API server can be started.
+  - ``api_url``: Must be set to the URL which exposes the web endpoint
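
Taken together, a minimal ``airflow.cfg`` sketch could look as follows; the secret and the URL are placeholders (check the provider's :doc:`configurations-ref` for the exact endpoint path), not values to copy verbatim:

```ini
[core]
executor = airflow.providers.edge.executors.EdgeExecutor
# Shared secret between webserver and Edge Worker - a random string,
# e.g. from: python -c "import secrets; print(secrets.token_urlsafe(32))"
internal_api_secret_key = <random-secret-string>

[edge]
api_enabled = true
# URL under which the webserver exposes the Edge Worker endpoint (placeholder)
api_url = https://<your-webserver-host>/<edge-api-endpoint>
```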
+
+To kick off a worker, you need to set up Airflow and start the worker
+subcommand:
+
+.. code-block:: bash
+
+    airflow edge worker
+
+Your worker should start picking up tasks as soon as they get fired in
+its direction. To stop a worker running on a machine you can use:
+
+.. code-block:: bash
+
+    airflow edge stop
+
+It will try to stop the worker gracefully by sending a ``SIGINT`` signal to the main
+process and waiting until all running tasks are completed.
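
The graceful-stop behaviour described above can be illustrated with a small self-contained Python sketch. This is not the provider's actual implementation, just the general signal-handling pattern (record the request, drain running work, then exit):

```python
import os
import signal

shutdown_requested = False


def _request_shutdown(signum, frame):
    # Only record the request; running tasks are drained afterwards.
    global shutdown_requested
    shutdown_requested = True


signal.signal(signal.SIGINT, _request_shutdown)

running_tasks = ["task_a", "task_b"]  # stand-ins for active jobs

# Simulate `airflow edge stop` sending SIGINT to the worker's main process.
os.kill(os.getpid(), signal.SIGINT)

# Drain loop: finish everything that is already running, accept nothing new.
while running_tasks:
    running_tasks.pop()  # pretend the task completed
```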
+
+If you want to monitor the remote activity and workers, use the UI plugin that
+is included in the provider package: install it on the webserver and use the
+"Admin" - "Edge Worker Hosts" and "Edge Worker Jobs" pages.
+
+
+Some caveats:
+
+- Tasks can consume resources. Make sure your worker has enough resources to run ``worker_concurrency`` tasks
+- Queue names are limited to 256 characters
+
+See :doc:`apache-airflow:administration-and-deployment/modules_management` for details on how Python and Airflow manage modules.
+
+Limitations of Pre-Release
+--------------------------
+
+As this provider package is an experimental preview, not all functions are supported and fully covered.
+If you plan to use the Edge Executor / Worker at the current stage, you need to ensure you test properly
+before use. The following features have been initially tested and are working:
+
+- Some core operators and features
+
+  - ``BashOperator``
+  - ``PythonOperator``
+  - ``@task`` decorator
+  - ``@task.branch`` decorator
+  - ``@task.virtualenv`` decorator
+  - ``@task.bash`` decorator
+  - Dynamic Mapped Tasks
+  - XCom read/write
+  - Variable and Connection access
+  - Setup and Teardown tasks
+
+- Some known limitations
+
+  - Tasks that require DB access will fail - no DB connection from the remote site is possible
+  - This also means that some direct Airflow API usage via Python (e.g. airflow.models.*) is not possible
+  - Log upload will only work if you use a single web server instance, or the instances need to share one log file volume.
+  - Performance: No performance assessment and scaling tests have been made. The Edge Executor package is not
+    optimized for scalability. This will need to be considered in future releases. A dedicated performance
+    assessment ensuring that other executors in a hybrid setup are not impacted is to be completed before
+    version 1.0.0 is released.
+  - Stuck tasks in the queue are not explicitly handled, as ``cleanup_stuck_queued_tasks()`` is not implemented.
+
+
+Architecture
+------------
+
+.. graphviz::
+
+    digraph A{
+        rankdir="TB"
+        node[shape="rectangle", style="rounded"]
+
+
+        subgraph cluster {
+            label="Cluster";
+            {rank = same; dag; database}
+            {rank = same; workers; scheduler; web}
+
+            workers[label="(Central) Workers"]
+            scheduler[label="Scheduler"]
+            web[label="Web server"]
+            database[label="Database"]
+            dag[label="DAG files"]
+
+            web->workers
+            web->database
+
+            workers->dag
+            workers->database
+
+            scheduler->dag
+            scheduler->database
+        }
+
+        subgraph edge_worker_subgraph {
+            label="Edge site";
+            edge_worker[label="Edge Worker"]
+            edge_dag[label="DAG files (Remote)"]
+
+            edge_worker->edge_dag
+        }
+
+        edge_worker->web[label="HTTP(s)"]
+    }
+
+Airflow consists of several components:
+
+* **Workers** - Execute the assigned tasks - most standard setups have local or centralized workers, e.g. via Celery
+* **Edge Workers** - Special workers which pull tasks via HTTP, a feature provided by this provider package
+* **Scheduler** - Responsible for adding the necessary tasks to the queue
+* **Web server** - HTTP server that provides access to DAG/task status information
+* **Database** - Contains information about the status of tasks, DAGs, Variables, connections, etc.
+
+
+.. _edge_executor:queue:
+
+Queues
+------
+
+When using the EdgeExecutor, the workers that tasks are sent to
+can be specified. ``queue`` is an attribute of BaseOperator, so any
+task can be assigned to any queue. The default queue for the environment
+is defined in the ``airflow.cfg``'s ``operators -> default_queue``. This defines
+the queue that tasks get assigned to when not specified, as well as which
+queue Airflow workers listen to when started.
+
+Workers can listen to one or multiple queues of tasks. When a worker is
+started (using the command ``airflow edge worker``), a set of comma-delimited queue
+names (with no whitespace) can be given (e.g. ``airflow edge worker -q remote,wisconsin_site``).
+This worker will then only pick up tasks wired to the specified queue(s).
+
+This can be useful if you need specialized workers, either from a
+resource perspective (for, say, very lightweight tasks where one worker
+could take thousands of tasks without a problem), or from an environment
+perspective (you want a worker running from a specific location where the required
+infrastructure is available).
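
The queue handling described above (comma-delimited names, no whitespace, 256-character limit) can be sketched as a small validation helper; ``parse_queues`` is illustrative, not part of the provider API:

```python
def parse_queues(queues_arg: str) -> list[str]:
    """Split a comma-delimited queue argument as `airflow edge worker -q` expects."""
    queues = queues_arg.split(",")
    for queue in queues:
        if not queue or " " in queue:
            raise ValueError(f"Invalid queue name: {queue!r}")
        if len(queue) > 256:  # queue names are limited to 256 characters
            raise ValueError(f"Queue name too long: {queue!r}")
    return queues


print(parse_queues("remote,wisconsin_site"))  # ['remote', 'wisconsin_site']
```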
+
+Feature Backlog of MVP to Release Readiness
+-------------------------------------------
+
+As noted above, the current version of the EdgeExecutor is an MVP (Minimum Viable Product).
+It can be used, but must be handled with care if you want to use it productively. Only the
+bare minimum functions are provided currently, and missing features will be added over time.
+
+The target implementation is sketched in
+`AIP-69 (Airflow Improvement Proposal for Edge Executor) <https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=301795932>`_
+and this AIP will be completed when the open features are implemented and it has production-grade stability.
+
+The following features are known missing and will be implemented in increments:
+
+- API token per worker: Today only a global API token is available
+- Edge Worker Plugin
+
+  - Overview about queues / jobs per queue
+  - Allow starting the Edge Worker REST API separately from the webserver
+  - Administrative maintenance / temporary disable jobs on worker
+
+- Edge Worker CLI
+
+  - Use WebSockets instead of HTTP calls for communication
+  - Handle SIG-INT/CTRL+C and gracefully terminate and complete the job (``airflow edge stop`` is working though)
+  - Also send logs to the TaskFileHandler if external logging services are used
+  - Integration into telemetry to send metrics from remote site
+  - Allow ``airflow edge stop`` to wait until the worker has completed and terminated
+  - Publish system metrics with heartbeats (CPU, Disk space, RAM, Load)
+  - Be more liberal e.g. on the patch version. The MVP requires an exact version match
+
+- Tests
+
+  - Integration tests in GitHub
+  - Test/Support on Windows for Edge Worker
+
+- Scaling test - Check and define boundaries of workers/jobs
+- Load tests - impact of scaled execution and code optimization
+- Airflow 3 / AIP-72 Migration
+
+  - Thin deployment based on Task SDK
+  - DAG code push (no need for Git sync)
+  - Implicit with AIP-72: Move task context generation from Remote to Executor
+
+- Documentation
+
+  - Describe more details on deployment options and tuning
+  - Provide scripts and guides to install edge components as a service (systemd)
+  - Extend Helm-Chart for needed support
diff --git a/docs/apache-airflow-providers-edge/index.rst b/docs/apache-airflow-providers-edge/index.rst
index 8b78170b34..9f456e8f89 100644
--- a/docs/apache-airflow-providers-edge/index.rst
+++ b/docs/apache-airflow-providers-edge/index.rst
@@ -30,6 +30,14 @@
     Security <security>
 
 
+.. toctree::
+    :hidden:
+    :maxdepth: 1
+    :caption: Executors
+
+    EdgeExecutor details <edge_executor>
+
+
 .. toctree::
     :hidden:
     :maxdepth: 1
diff --git a/providers/src/airflow/providers/edge/executors/__init__.py b/providers/src/airflow/providers/edge/executors/__init__.py
new file mode 100644
index 0000000000..1af34c51e0
--- /dev/null
+++ b/providers/src/airflow/providers/edge/executors/__init__.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from airflow.providers.edge.executors.edge_executor import EdgeExecutor
+
+__all__ = ["EdgeExecutor"]
diff --git a/providers/src/airflow/providers/edge/executors/edge_executor.py b/providers/src/airflow/providers/edge/executors/edge_executor.py
new file mode 100644
index 0000000000..f673e8fa6f
--- /dev/null
+++ b/providers/src/airflow/providers/edge/executors/edge_executor.py
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, Any, Sequence
+
+from sqlalchemy import delete
+
+from airflow.cli.cli_config import GroupCommand
+from airflow.configuration import conf
+from airflow.executors.base_executor import BaseExecutor
+from airflow.models.abstractoperator import DEFAULT_QUEUE
+from airflow.models.taskinstance import TaskInstanceState
+from airflow.providers.edge.models.edge_job import EdgeJobModel
+from airflow.providers.edge.models.edge_logs import EdgeLogsModel
+from airflow.providers.edge.models.edge_worker import EdgeWorkerModel
+from airflow.utils.db import DBLocks, create_global_lock
+from airflow.utils.session import NEW_SESSION, provide_session
+
+if TYPE_CHECKING:
+    import argparse
+
+    from sqlalchemy.orm import Session
+
+    from airflow.executors.base_executor import CommandType
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.models.taskinstancekey import TaskInstanceKey
+
+PARALLELISM: int = conf.getint("core", "PARALLELISM")
+
+
+class EdgeExecutor(BaseExecutor):
+    """Implementation of the EdgeExecutor to distribute work to Edge Workers 
via HTTP."""
+
+    def __init__(self, parallelism: int = PARALLELISM):
+        super().__init__(parallelism=parallelism)
+        self.last_reported_state: dict[TaskInstanceKey, TaskInstanceState] = {}
+
+    @provide_session
+    def start(self, session: Session = NEW_SESSION):
+        """If EdgeExecutor provider is loaded first time, ensure table 
exists."""
+        with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
+            engine = session.get_bind().engine
+            EdgeJobModel.metadata.create_all(engine)
+            EdgeLogsModel.metadata.create_all(engine)
+            EdgeWorkerModel.metadata.create_all(engine)
+
+    @provide_session
+    def execute_async(
+        self,
+        key: TaskInstanceKey,
+        command: CommandType,
+        queue: str | None = None,
+        executor_config: Any | None = None,
+        session: Session = NEW_SESSION,
+    ) -> None:
+        """Execute asynchronously."""
+        self.validate_airflow_tasks_run_command(command)
+        session.add(
+            EdgeJobModel(
+                dag_id=key.dag_id,
+                task_id=key.task_id,
+                run_id=key.run_id,
+                map_index=key.map_index,
+                try_number=key.try_number,
+                state=TaskInstanceState.QUEUED,
+                queue=queue or DEFAULT_QUEUE,
+                command=str(command),
+            )
+        )
+
+    @provide_session
+    def sync(self, session: Session = NEW_SESSION) -> None:
+        """Sync will get called periodically by the heartbeat method."""
+        purged_marker = False
+        job_success_purge = conf.getint("edge", "job_success_purge")
+        job_fail_purge = conf.getint("edge", "job_fail_purge")
+        jobs: list[EdgeJobModel] = (
+            session.query(EdgeJobModel)
+            .filter(
+                EdgeJobModel.state.in_(
+                    [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS, TaskInstanceState.FAILED]
+                )
+            )
+            .all()
+        )
+        for job in jobs:
+            if job.key in self.running:
+                if job.state == TaskInstanceState.RUNNING:
+                    if (
+                        job.key not in self.last_reported_state
+                        or self.last_reported_state[job.key] != job.state
+                    ):
+                        self.running_state(job.key)
+                    self.last_reported_state[job.key] = job.state
+                elif job.state == TaskInstanceState.SUCCESS:
+                    if job.key in self.last_reported_state:
+                        del self.last_reported_state[job.key]
+                    self.success(job.key)
+                elif job.state == TaskInstanceState.FAILED:
+                    if job.key in self.last_reported_state:
+                        del self.last_reported_state[job.key]
+                    self.fail(job.key)
+                else:
+                    self.last_reported_state[job.key] = job.state
+            if (
+                job.state == TaskInstanceState.SUCCESS
+                and job.last_update_t < (datetime.now() - timedelta(minutes=job_success_purge)).timestamp()
+            ) or (
+                job.state == TaskInstanceState.FAILED
+                and job.last_update_t < (datetime.now() - timedelta(minutes=job_fail_purge)).timestamp()
+            ):
+                if job.key in self.last_reported_state:
+                    del self.last_reported_state[job.key]
+                purged_marker = True
+                session.delete(job)
+                session.execute(
+                    delete(EdgeLogsModel).where(
+                        EdgeLogsModel.dag_id == job.dag_id,
+                        EdgeLogsModel.run_id == job.run_id,
+                        EdgeLogsModel.task_id == job.task_id,
+                        EdgeLogsModel.map_index == job.map_index,
+                        EdgeLogsModel.try_number == job.try_number,
+                    )
+                )
+        if purged_marker:
+            session.commit()
+
+    def end(self) -> None:
+        """End the executor."""
+        self.log.info("Shutting down EdgeExecutor")
+
+    def terminate(self):
+        """Terminate the executor; this is intentionally a no-op."""
+
+    def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
+        """
+        Try to adopt running task instances that have been abandoned by a SchedulerJob dying.
+
+        Anything that is not adopted will be cleared by the scheduler (and then become eligible for
+        re-scheduling)
+
+        :return: any TaskInstances that were unable to be adopted
+        """
+        # We handle all running tasks from the DB in sync(); no adoption logic is needed.
+        return []
+
+    @staticmethod
+    def get_cli_commands() -> list[GroupCommand]:
+        return [
+            GroupCommand(
+                name="edge",
+                help="Edge Worker components",
+                description=(
+                    "Start and manage Edge Worker. Works only when using EdgeExecutor. For more information, "
+                    "see https://airflow.apache.org/docs/apache-airflow-providers-edge/stable/edge_executor.html"
+                ),
+                subcommands=[],
+            ),
+        ]
+
+
+def _get_parser() -> argparse.ArgumentParser:
+    """
+    Generate documentation; used by Sphinx.
+
+    :meta private:
+    """
+    return EdgeExecutor._get_parser()
diff --git a/providers/src/airflow/providers/edge/provider.yaml b/providers/src/airflow/providers/edge/provider.yaml
index d6644271a0..7f61ad128c 100644
--- a/providers/src/airflow/providers/edge/provider.yaml
+++ b/providers/src/airflow/providers/edge/provider.yaml
@@ -36,6 +36,9 @@ plugins:
   - name: edge_executor
    plugin-class: airflow.providers.edge.plugins.edge_executor_plugin.EdgeExecutorPlugin
 
+executors:
+  - airflow.providers.edge.executors.EdgeExecutor
+
 config:
   edge:
     description: |
diff --git a/providers/tests/edge/executors/__init__.py b/providers/tests/edge/executors/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/providers/tests/edge/executors/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/tests/edge/executors/test_edge_executor.py b/providers/tests/edge/executors/test_edge_executor.py
new file mode 100644
index 0000000000..de45384fd7
--- /dev/null
+++ b/providers/tests/edge/executors/test_edge_executor.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import patch
+
+import pytest
+
+from airflow.models.taskinstancekey import TaskInstanceKey
+from airflow.providers.edge.executors.edge_executor import EdgeExecutor
+from airflow.providers.edge.models.edge_job import EdgeJobModel
+from airflow.utils import timezone
+from airflow.utils.session import create_session
+from airflow.utils.state import TaskInstanceState
+
+pytestmark = pytest.mark.db_test
+
+
+class TestEdgeExecutor:
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self):
+        with create_session() as session:
+            session.query(EdgeJobModel).delete()
+
+    def test_execute_async_bad_command(self):
+        executor = EdgeExecutor()
+        with pytest.raises(ValueError):
+            executor.execute_async(
+                TaskInstanceKey(
+                    dag_id="test_dag", run_id="test_run", task_id="test_task", map_index=-1, try_number=1
+                ),
+                command=["hello", "world"],
+            )
+
+    def test_execute_async_ok_command(self):
+        executor = EdgeExecutor()
+        executor.execute_async(
+            TaskInstanceKey(
+                dag_id="test_dag", run_id="test_run", task_id="test_task", map_index=-1, try_number=1
+            ),
+            command=["airflow", "tasks", "run", "hello", "world"],
+        )
+        with create_session() as session:
+            jobs: list[EdgeJobModel] = session.query(EdgeJobModel).all()
+        assert len(jobs) == 1
+        assert jobs[0].dag_id == "test_dag"
+        assert jobs[0].run_id == "test_run"
+        assert jobs[0].task_id == "test_task"
+
+    @patch("airflow.providers.edge.executors.edge_executor.EdgeExecutor.running_state")
+    @patch("airflow.providers.edge.executors.edge_executor.EdgeExecutor.success")
+    @patch("airflow.providers.edge.executors.edge_executor.EdgeExecutor.fail")
+    def test_sync(self, mock_fail, mock_success, mock_running_state):
+        executor = EdgeExecutor()
+
+        def remove_from_running(key: TaskInstanceKey):
+            executor.running.remove(key)
+
+        mock_success.side_effect = remove_from_running
+        mock_fail.side_effect = remove_from_running
+
+        # Prepare some data
+        with create_session() as session:
+            for task_id, state in [
+                ("started_running", TaskInstanceState.RUNNING),
+                ("started_success", TaskInstanceState.SUCCESS),
+                ("started_failed", TaskInstanceState.FAILED),
+            ]:
+                session.add(
+                    EdgeJobModel(
+                        dag_id="test_dag",
+                        task_id=task_id,
+                        run_id="test_run",
+                        map_index=-1,
+                        try_number=1,
+                        state=state,
+                        queue="default",
+                        command="dummy",
+                        last_update=timezone.utcnow(),
+                    )
+                )
+                key = TaskInstanceKey(
+                    dag_id="test_dag", run_id="test_run", task_id=task_id, map_index=-1, try_number=1
+                )
+                executor.running.add(key)
+                session.commit()
+        assert len(executor.running) == 3
+
+        executor.sync()
+
+        assert len(executor.running) == 1
+        mock_running_state.assert_called_once()
+        mock_success.assert_called_once()
+        mock_fail.assert_called_once()
+
+        # Now test another round with one new run
+        mock_running_state.reset_mock()
+        mock_success.reset_mock()
+        mock_fail.reset_mock()
+
+        with create_session() as session:
+            task_id = "started_running2"
+            state = TaskInstanceState.RUNNING
+            session.add(
+                EdgeJobModel(
+                    dag_id="test_dag",
+                    task_id=task_id,
+                    run_id="test_run",
+                    map_index=-1,
+                    try_number=1,
+                    state=state,
+                    queue="default",
+                    command="dummy",
+                    last_update=timezone.utcnow(),
+                )
+            )
+            key = TaskInstanceKey(
+                dag_id="test_dag", run_id="test_run", task_id=task_id, map_index=-1, try_number=1
+            )
+            executor.running.add(key)
+            session.commit()
+        assert len(executor.running) == 2
+
+        executor.sync()
+
+        assert len(executor.running) == 2
+        mock_running_state.assert_called_once()  # the first run was already reported; the new run triggers the call
+        mock_success.assert_not_called()  # already reported earlier, not running anymore
+        mock_fail.assert_not_called()  # already reported earlier, not running anymore
+        mock_running_state.reset_mock()
+
+        executor.sync()
+
+        assert len(executor.running) == 2
+        # Now none is called as we called before already
+        mock_running_state.assert_not_called()
+        mock_success.assert_not_called()
+        mock_fail.assert_not_called()
