This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 788b9c486b Add DB models for Edge Provider (#42047)
788b9c486b is described below

commit 788b9c486bf9e42fb4b10a30edef7f536bb873d6
Author: Jens Scheffler <[email protected]>
AuthorDate: Mon Sep 16 21:55:56 2024 +0200

    Add DB models for Edge Provider (#42047)
---
 airflow/providers/edge/example_dags/__init__.py    |  16 ++
 .../edge/example_dags/integration_test.py          | 139 ++++++++++++++
 airflow/providers/edge/models/__init__.py          |  16 ++
 airflow/providers/edge/models/edge_job.py          | 185 +++++++++++++++++++
 airflow/providers/edge/models/edge_logs.py         | 153 ++++++++++++++++
 airflow/providers/edge/models/edge_worker.py       | 203 +++++++++++++++++++++
 tests/providers/edge/models/__init__.py            |  17 ++
 tests/providers/edge/models/test_edge_job.py       |  99 ++++++++++
 tests/providers/edge/models/test_edge_logs.py      |  49 +++++
 tests/providers/edge/models/test_edge_worker.py    |  65 +++++++
 10 files changed, 942 insertions(+)

diff --git a/airflow/providers/edge/example_dags/__init__.py 
b/airflow/providers/edge/example_dags/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/edge/example_dags/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/edge/example_dags/integration_test.py 
b/airflow/providers/edge/example_dags/integration_test.py
new file mode 100644
index 0000000000..d6074abd30
--- /dev/null
+++ b/airflow/providers/edge/example_dags/integration_test.py
@@ -0,0 +1,139 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+In this DAG all critical functions are contained as an integration test.
+
+The DAG should work in all standard setups without error.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+from time import sleep
+
+from airflow.decorators import task, task_group
+from airflow.exceptions import AirflowNotFoundException
+from airflow.hooks.base import BaseHook
+from airflow.models.dag import DAG
+from airflow.models.param import Param
+from airflow.models.variable import Variable
+from airflow.operators.bash import BashOperator
+from airflow.operators.empty import EmptyOperator
+from airflow.operators.python import PythonOperator
+
+with DAG(
+    dag_id="integration_test",
+    dag_display_name="Integration Test",
+    description=__doc__.partition(".")[0],
+    doc_md=__doc__,
+    schedule=None,
+    start_date=datetime(2024, 7, 1),
+    tags=["example", "params", "integration test"],
+    params={
+        "mapping_count": Param(
+            4,
+            type="integer",
+            title="Mapping Count",
+            description="Amount of tasks that should be mapped",
+        ),
+    },
+) as dag:
+
+    @task
+    def my_setup():
+        print("Assume this is a setup task")
+
+    @task
+    def mapping_from_params(**context) -> list[int]:
+        mapping_count: int = context["params"]["mapping_count"]
+        return list(range(1, mapping_count + 1))
+
+    @task
+    def add_one(x: int):
+        return x + 1
+
+    @task
+    def sum_it(values):
+        total = sum(values)
+        print(f"Total was {total}")
+
+    @task_group(prefix_group_id=False)
+    def mapping_task_group():
+        added_values = add_one.expand(x=mapping_from_params())
+        sum_it(added_values)
+
+    @task.branch
+    def branching():
+        return ["bash", "virtualenv", "variable", "connection", 
"classic_bash", "classic_python"]
+
+    @task.bash
+    def bash():
+        return "echo hello world"
+
+    @task.virtualenv(requirements="numpy")
+    def virtualenv():
+        import numpy
+
+        print(f"Welcome to virtualenv with numpy version {numpy.__version__}.")
+
+    @task
+    def variable():
+        Variable.set("integration_test_key", "value")
+        assert Variable.get("integration_test_key") == "value"  # noqa: S101
+        Variable.delete("integration_test_key")
+
+    @task
+    def connection():
+        try:
+            conn = BaseHook.get_connection("integration_test")
+            print(f"Got connection {conn}")
+        except AirflowNotFoundException:
+            print("Connection not found... but also OK.")
+
+    @task_group(prefix_group_id=False)
+    def standard_tasks_group():
+        classic_bash = BashOperator(
+            task_id="classic_bash", bash_command="echo Parameter is {{ 
params.mapping_count }}"
+        )
+
+        empty = EmptyOperator(task_id="not_executed")
+
+        def python_call():
+            print("Hello world")
+
+        classic_py = PythonOperator(task_id="classic_python", 
python_callable=python_call)
+
+        branching() >> [bash(), virtualenv(), variable(), connection(), 
classic_bash, classic_py, empty]
+
+    @task
+    def long_running():
+        print("This task runs for 15 minutes")
+        for i in range(15):
+            sleep(60)
+            print(f"Running for {i + 1} minutes now.")
+        print("Long running task completed.")
+
+    @task
+    def my_teardown():
+        print("Assume this is a teardown task")
+
+    (
+        my_setup().as_setup()
+        >> [mapping_task_group(), standard_tasks_group(), long_running()]
+        >> my_teardown().as_teardown()
+    )
diff --git a/airflow/providers/edge/models/__init__.py 
b/airflow/providers/edge/models/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/edge/models/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/edge/models/edge_job.py 
b/airflow/providers/edge/models/edge_job.py
new file mode 100644
index 0000000000..b6e316e1f7
--- /dev/null
+++ b/airflow/providers/edge/models/edge_job.py
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from ast import literal_eval
+from datetime import datetime
+from typing import TYPE_CHECKING, List, Optional
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import (
+    Column,
+    Index,
+    Integer,
+    String,
+    select,
+    text,
+)
+
+from airflow.api_internal.internal_api_call import internal_api_call
+from airflow.models.base import Base, StringID
+from airflow.models.taskinstancekey import TaskInstanceKey
+from airflow.serialization.serialized_objects import 
add_pydantic_class_type_mapping
+from airflow.utils import timezone
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
+
+
+class EdgeJobModel(Base, LoggingMixin):
+    """
+    A job which is queued, waiting or running on an Edge Worker.
+
+    Each tuple in the database represents and describes the state of one job.
+    """
+
+    __tablename__ = "edge_job"
+    dag_id = Column(StringID(), primary_key=True, nullable=False)
+    task_id = Column(StringID(), primary_key=True, nullable=False)
+    run_id = Column(StringID(), primary_key=True, nullable=False)
+    map_index = Column(Integer, primary_key=True, nullable=False, 
server_default=text("-1"))
+    try_number = Column(Integer, primary_key=True, default=0)
+    state = Column(String(20))
+    queue = Column(String(256))
+    command = Column(String(1000))
+    queued_dttm = Column(UtcDateTime)
+    edge_worker = Column(String(64))
+    last_update = Column(UtcDateTime)
+
+    def __init__(
+        self,
+        dag_id: str,
+        task_id: str,
+        run_id: str,
+        map_index: int,
+        try_number: int,
+        state: str,
+        queue: str,
+        command: str,
+        queued_dttm: datetime | None = None,
+        edge_worker: str | None = None,
+        last_update: datetime | None = None,
+    ):
+        self.dag_id = dag_id
+        self.task_id = task_id
+        self.run_id = run_id
+        self.map_index = map_index
+        self.try_number = try_number
+        self.state = state
+        self.queue = queue
+        self.command = command
+        self.queued_dttm = queued_dttm or timezone.utcnow()
+        self.edge_worker = edge_worker
+        self.last_update = last_update
+        super().__init__()
+
+    __table_args__ = (Index("rj_order", state, queued_dttm, queue),)
+
+    @property
+    def key(self):
+        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, 
self.try_number, self.map_index)
+
+    @property
+    def last_update_t(self) -> float:
+        return self.last_update.timestamp()
+
+
+class EdgeJob(BaseModel, LoggingMixin):
+    """Accessor for edge jobs as logical model."""
+
+    dag_id: str
+    task_id: str
+    run_id: str
+    map_index: int
+    try_number: int
+    state: TaskInstanceState
+    queue: str
+    command: List[str]  # noqa: UP006 - prevent Sphinx failing
+    queued_dttm: datetime
+    edge_worker: Optional[str]  # noqa: UP007 - prevent Sphinx failing
+    last_update: Optional[datetime]  # noqa: UP007 - prevent Sphinx failing
+    model_config = ConfigDict(from_attributes=True, 
arbitrary_types_allowed=True)
+
+    @property
+    def key(self) -> TaskInstanceKey:
+        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, 
self.try_number, self.map_index)
+
+    @staticmethod
+    @internal_api_call
+    @provide_session
+    def reserve_task(
+        worker_name: str, queues: list[str] | None = None, session: Session = 
NEW_SESSION
+    ) -> EdgeJob | None:
+        query = (
+            select(EdgeJobModel)
+            .where(EdgeJobModel.state == TaskInstanceState.QUEUED)
+            .order_by(EdgeJobModel.queued_dttm)
+        )
+        if queues:
+            query = query.where(EdgeJobModel.queue.in_(queues))
+        query = query.limit(1)
+        query = with_row_locks(query, of=EdgeJobModel, session=session, 
skip_locked=True)
+        job: EdgeJobModel = session.scalar(query)
+        if not job:
+            return None
+        job.state = TaskInstanceState.RUNNING
+        job.edge_worker = worker_name
+        job.last_update = timezone.utcnow()
+        session.commit()
+        return EdgeJob(
+            dag_id=job.dag_id,
+            task_id=job.task_id,
+            run_id=job.run_id,
+            map_index=job.map_index,
+            try_number=job.try_number,
+            state=job.state,
+            queue=job.queue,
+            command=literal_eval(job.command),
+            queued_dttm=job.queued_dttm,
+            edge_worker=job.edge_worker,
+            last_update=job.last_update,
+        )
+
+    @staticmethod
+    @internal_api_call
+    @provide_session
+    def set_state(task: TaskInstanceKey | tuple, state: TaskInstanceState, 
session: Session = NEW_SESSION):
+        if isinstance(task, tuple):
+            task = TaskInstanceKey(*task)
+        query = select(EdgeJobModel).where(
+            EdgeJobModel.dag_id == task.dag_id,
+            EdgeJobModel.task_id == task.task_id,
+            EdgeJobModel.run_id == task.run_id,
+            EdgeJobModel.map_index == task.map_index,
+            EdgeJobModel.try_number == task.try_number,
+        )
+        job: EdgeJobModel = session.scalar(query)
+        job.state = state
+        job.last_update = timezone.utcnow()
+        session.commit()
+
+    def __hash__(self):
+        return 
f"{self.dag_id}|{self.task_id}|{self.run_id}|{self.map_index}|{self.try_number}".__hash__()
+
+
+EdgeJob.model_rebuild()
+
+add_pydantic_class_type_mapping("edge_job", EdgeJobModel, EdgeJob)
diff --git a/airflow/providers/edge/models/edge_logs.py 
b/airflow/providers/edge/models/edge_logs.py
new file mode 100644
index 0000000000..29625f5be7
--- /dev/null
+++ b/airflow/providers/edge/models/edge_logs.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+from functools import lru_cache
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import (
+    Column,
+    Integer,
+    Text,
+    text,
+)
+from sqlalchemy.dialects.mysql import MEDIUMTEXT
+
+from airflow.api_internal.internal_api_call import internal_api_call
+from airflow.configuration import conf
+from airflow.models.base import Base, StringID
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskinstancekey import TaskInstanceKey
+from airflow.serialization.serialized_objects import 
add_pydantic_class_type_mapping
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import UtcDateTime
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
+
+
+class EdgeLogsModel(Base, LoggingMixin):
+    """
+    Temporarily collected logs from an Edge Worker while a job runs on a remote site.
+
+    As the Edge Worker in most cases has a local file system and the web UI has no
+    access to read files from the remote site, Edge Workers send incremental chunks
+    of logs of running jobs to the central site. As log storage backends in most
+    cloud setups cannot append logs, this table is used as a buffer to receive them.
+    Upon task completion the logs can be flushed to the task log handler.
+
+    Log data therefore is collected in chunks and is only temporary.
+    """
+
+    __tablename__ = "edge_logs"
+    dag_id = Column(StringID(), primary_key=True, nullable=False)
+    task_id = Column(StringID(), primary_key=True, nullable=False)
+    run_id = Column(StringID(), primary_key=True, nullable=False)
+    map_index = Column(Integer, primary_key=True, nullable=False, 
server_default=text("-1"))
+    try_number = Column(Integer, primary_key=True, default=0)
+    log_chunk_time = Column(UtcDateTime, primary_key=True, nullable=False)
+    log_chunk_data = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), 
nullable=False)
+
+    def __init__(
+        self,
+        dag_id: str,
+        task_id: str,
+        run_id: str,
+        map_index: int,
+        try_number: int,
+        log_chunk_time: datetime,
+        log_chunk_data: str,
+    ):
+        self.dag_id = dag_id
+        self.task_id = task_id
+        self.run_id = run_id
+        self.map_index = map_index
+        self.try_number = try_number
+        self.log_chunk_time = log_chunk_time
+        self.log_chunk_data = log_chunk_data
+        super().__init__()
+
+
+class EdgeLogs(BaseModel, LoggingMixin):
+    """Accessor for edge log chunks as a logical model."""
+
+    dag_id: str
+    task_id: str
+    run_id: str
+    map_index: int
+    try_number: int
+    log_chunk_time: datetime
+    log_chunk_data: str
+    model_config = ConfigDict(from_attributes=True, 
arbitrary_types_allowed=True)
+
+    @staticmethod
+    @internal_api_call
+    @provide_session
+    def push_logs(
+        task: TaskInstanceKey | tuple,
+        log_chunk_time: datetime,
+        log_chunk_data: str,
+        session: Session = NEW_SESSION,
+    ) -> None:
+        """Push an incremental log chunk from an Edge Worker to the central site."""
+        if isinstance(task, tuple):
+            task = TaskInstanceKey(*task)
+        log_chunk = EdgeLogsModel(
+            dag_id=task.dag_id,
+            task_id=task.task_id,
+            run_id=task.run_id,
+            map_index=task.map_index,
+            try_number=task.try_number,
+            log_chunk_time=log_chunk_time,
+            log_chunk_data=log_chunk_data,
+        )
+        session.add(log_chunk)
+        # Write logs to local file to make them accessible
+        logfile_path = EdgeLogs.logfile_path(task)
+        if not logfile_path.exists():
+            new_folder_permissions = int(
+                conf.get("logging", 
"file_task_handler_new_folder_permissions", fallback="0o775"), 8
+            )
+            logfile_path.parent.mkdir(parents=True, exist_ok=True, 
mode=new_folder_permissions)
+        with logfile_path.open("a") as logfile:
+            logfile.write(log_chunk_data)
+
+    @staticmethod
+    @lru_cache
+    def logfile_path(task: TaskInstanceKey) -> Path:
+        """Elaborate the path and filename to expect from task execution."""
+        from airflow.utils.log.file_task_handler import FileTaskHandler
+
+        ti = TaskInstance.get_task_instance(
+            dag_id=task.dag_id,
+            run_id=task.run_id,
+            task_id=task.task_id,
+            map_index=task.map_index,
+        )
+        if TYPE_CHECKING:
+            assert ti
+        base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT 
AVAILABLE")
+        return Path(base_log_folder, 
FileTaskHandler(base_log_folder)._render_filename(ti, task.try_number))
+
+
+EdgeLogs.model_rebuild()
+
+add_pydantic_class_type_mapping("edge_logs", EdgeLogsModel, EdgeLogs)
diff --git a/airflow/providers/edge/models/edge_worker.py 
b/airflow/providers/edge/models/edge_worker.py
new file mode 100644
index 0000000000..193795e37d
--- /dev/null
+++ b/airflow/providers/edge/models/edge_worker.py
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from datetime import datetime
+from enum import Enum
+from typing import TYPE_CHECKING, List, Optional
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import (
+    Column,
+    Integer,
+    String,
+    select,
+)
+
+from airflow.api_internal.internal_api_call import internal_api_call
+from airflow.exceptions import AirflowException
+from airflow.models.base import Base
+from airflow.serialization.serialized_objects import 
add_pydantic_class_type_mapping
+from airflow.utils import timezone
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import UtcDateTime
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
+
+
+class EdgeWorkerVersionException(AirflowException):
+    """Signal a version mismatch between core and Edge Site."""
+
+    pass
+
+
+class EdgeWorkerState(str, Enum):
+    """Status of an Edge Worker instance."""
+
+    STARTING = "starting"
+    """Edge Worker is in initialization."""
+    RUNNING = "running"
+    """Edge Worker is actively running a task."""
+    IDLE = "idle"
+    """Edge Worker is active and waiting for a task."""
+    TERMINATING = "terminating"
+    """Edge Worker is completing work and stopping."""
+    OFFLINE = "offline"
+    """Edge Worker was shut down."""
+    UNKNOWN = "unknown"
+    """No heartbeat signal from the worker for some time; the Edge Worker is probably down."""
+
+
+class EdgeWorkerModel(Base, LoggingMixin):
+    """An Edge Worker instance which reports its state and health."""
+
+    __tablename__ = "edge_worker"
+    worker_name = Column(String(64), primary_key=True, nullable=False)
+    state = Column(String(20))
+    queues = Column(String(256))
+    first_online = Column(UtcDateTime)
+    last_update = Column(UtcDateTime)
+    jobs_active = Column(Integer, default=0)
+    jobs_taken = Column(Integer, default=0)
+    jobs_success = Column(Integer, default=0)
+    jobs_failed = Column(Integer, default=0)
+    sysinfo = Column(String(256))
+
+    def __init__(
+        self,
+        worker_name: str,
+        state: str,
+        queues: list[str] | None,
+        first_online: datetime | None = None,
+        last_update: datetime | None = None,
+    ):
+        self.worker_name = worker_name
+        self.state = state
+        self.queues = ", ".join(queues) if queues else None
+        self.first_online = first_online or timezone.utcnow()
+        self.last_update = last_update
+        super().__init__()
+
+    @property
+    def sysinfo_json(self) -> dict:
+        return json.loads(self.sysinfo) if self.sysinfo else None
+
+
+class EdgeWorker(BaseModel, LoggingMixin):
+    """Accessor for Edge Worker instances as logical model."""
+
+    worker_name: str
+    state: EdgeWorkerState
+    queues: Optional[List[str]]  # noqa: UP006,UP007 - prevent Sphinx failing
+    first_online: datetime
+    last_update: Optional[datetime] = None  # noqa: UP007 - prevent Sphinx 
failing
+    jobs_active: int
+    jobs_taken: int
+    jobs_success: int
+    jobs_failed: int
+    sysinfo: str
+    model_config = ConfigDict(from_attributes=True, 
arbitrary_types_allowed=True)
+
+    @staticmethod
+    def assert_version(sysinfo: dict[str, str]) -> None:
+        """Check if the Edge Worker version matches the central API site."""
+        from airflow import __version__ as airflow_version
+        from airflow.providers.edge import __version__ as edge_provider_version
+
+        # Note: In future, more stable versions we might be more liberal; for the
+        #       moment we require an exact version match for Edge Worker and core
+        #       versions.
+        if "airflow_version" in sysinfo:
+            airflow_on_worker = sysinfo["airflow_version"]
+            if airflow_on_worker != airflow_version:
+                raise EdgeWorkerVersionException(
+                    f"Edge Worker runs on Airflow {airflow_on_worker} "
+                    f"and the core runs on {airflow_version}. Rejecting access 
due to difference."
+                )
+        else:
+            raise EdgeWorkerVersionException("Edge Worker does not specify the 
version it is running on.")
+
+        if "edge_provider_version" in sysinfo:
+            provider_on_worker = sysinfo["edge_provider_version"]
+            if provider_on_worker != edge_provider_version:
+                raise EdgeWorkerVersionException(
+                    f"Edge Worker runs on Edge Provider {provider_on_worker} "
+                    f"and the core runs on {edge_provider_version}. Rejecting 
access due to difference."
+                )
+        else:
+            raise EdgeWorkerVersionException(
+                "Edge Worker does not specify the provider version it is 
running on."
+            )
+
+    @staticmethod
+    @internal_api_call
+    @provide_session
+    def register_worker(
+        worker_name: str,
+        state: EdgeWorkerState,
+        queues: list[str] | None,
+        sysinfo: dict[str, str],
+        session: Session = NEW_SESSION,
+    ) -> EdgeWorker:
+        EdgeWorker.assert_version(sysinfo)
+        query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
+        worker: EdgeWorkerModel = session.scalar(query)
+        if not worker:
+            worker = EdgeWorkerModel(worker_name=worker_name, state=state, 
queues=queues)
+        worker.state = state
+        worker.queues = queues
+        worker.sysinfo = json.dumps(sysinfo)
+        worker.last_update = timezone.utcnow()
+        session.add(worker)
+        return EdgeWorker(
+            worker_name=worker_name,
+            state=state,
+            queues=worker.queues,
+            first_online=worker.first_online,
+            last_update=worker.last_update,
+            jobs_active=worker.jobs_active or 0,
+            jobs_taken=worker.jobs_taken or 0,
+            jobs_success=worker.jobs_success or 0,
+            jobs_failed=worker.jobs_failed or 0,
+            sysinfo=worker.sysinfo or "{}",
+        )
+
+    @staticmethod
+    @internal_api_call
+    @provide_session
+    def set_state(
+        worker_name: str,
+        state: EdgeWorkerState,
+        jobs_active: int,
+        sysinfo: dict[str, str],
+        session: Session = NEW_SESSION,
+    ):
+        EdgeWorker.assert_version(sysinfo)
+        query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name)
+        worker: EdgeWorkerModel = session.scalar(query)
+        worker.state = state
+        worker.jobs_active = jobs_active
+        worker.sysinfo = json.dumps(sysinfo)
+        worker.last_update = timezone.utcnow()
+        session.commit()
+
+
+EdgeWorker.model_rebuild()
+
+add_pydantic_class_type_mapping("edge_worker", EdgeWorkerModel, EdgeWorker)
diff --git a/tests/providers/edge/models/__init__.py 
b/tests/providers/edge/models/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/tests/providers/edge/models/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/edge/models/test_edge_job.py 
b/tests/providers/edge/models/test_edge_job.py
new file mode 100644
index 0000000000..91dd897deb
--- /dev/null
+++ b/tests/providers/edge/models/test_edge_job.py
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pytest
+
+from airflow.providers.edge.models.edge_job import EdgeJob, EdgeJobModel
+from airflow.utils import timezone
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+pytestmark = pytest.mark.db_test
+pytest.importorskip("pydantic", minversion="2.0.0")
+
+
+class TestEdgeJob:
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, session: Session):
+        session.query(EdgeJobModel).delete()
+
+    def test_reserve_task_no_job(self):
+        job = EdgeJob.reserve_task("worker")
+        assert job is None
+
+    def test_reserve_task_has_one(self, session: Session):
+        rjm = EdgeJobModel(
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            map_index=-1,
+            try_number=1,
+            state=TaskInstanceState.QUEUED,
+            queue="default",
+            command=str(["hello", "world"]),
+            queued_dttm=timezone.utcnow(),
+        )
+        session.add(rjm)
+        session.commit()
+
+        job = EdgeJob.reserve_task("worker")
+        assert job
+        assert job.edge_worker == "worker"
+        assert job.queue == "default"
+        assert job.dag_id == "test_dag"
+        assert job.task_id == "test_task"
+        assert job.run_id == "test_run"
+
+        jobs: list[EdgeJobModel] = session.query(EdgeJobModel).all()
+        assert len(jobs) == 1
+        assert jobs[0].state == TaskInstanceState.RUNNING
+        assert jobs[0].edge_worker == "worker"
+        assert jobs[0].queue == "default"
+        assert jobs[0].dag_id == "test_dag"
+        assert jobs[0].task_id == "test_task"
+        assert jobs[0].run_id == "test_run"
+
+    def test_set_state(self, session: Session):
+        rjm = EdgeJobModel(
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            map_index=-1,
+            try_number=1,
+            state=TaskInstanceState.RUNNING,
+            queue="default",
+            command=str(["hello", "world"]),
+            queued_dttm=timezone.utcnow(),
+        )
+        session.add(rjm)
+        session.commit()
+
+        EdgeJob.set_state(rjm.key, TaskInstanceState.FAILED)
+
+        jobs: list[EdgeJobModel] = session.query(EdgeJobModel).all()
+        assert len(jobs) == 1
+        assert jobs[0].state == TaskInstanceState.FAILED
+        assert jobs[0].last_update
+        assert jobs[0].queue == "default"
+        assert jobs[0].dag_id == "test_dag"
+        assert jobs[0].task_id == "test_task"
+        assert jobs[0].run_id == "test_run"
diff --git a/tests/providers/edge/models/test_edge_logs.py 
b/tests/providers/edge/models/test_edge_logs.py
new file mode 100644
index 0000000000..0bb3307e0c
--- /dev/null
+++ b/tests/providers/edge/models/test_edge_logs.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.edge.models.edge_logs import EdgeLogs, EdgeLogsModel
+from airflow.utils import timezone
+
+pytestmark = pytest.mark.db_test
+
+pytest.importorskip("pydantic", minversion="2.0.0")
+
+
def test_serializing_pydantic_edge_logs():
    """EdgeLogs must survive a full ORM -> pydantic -> JSON -> pydantic round trip."""
    orm_record = EdgeLogsModel(
        dag_id="test_dag",
        task_id="test_task",
        run_id="test_run",
        map_index=-1,
        try_number=1,
        log_chunk_time=timezone.utcnow(),
        log_chunk_data="some logs captured",
    )

    # Validate from the ORM object, serialize to JSON, then parse it back.
    serialized = EdgeLogs.model_validate(orm_record).model_dump_json()
    print(serialized)
    round_tripped = EdgeLogs.model_validate_json(serialized)

    # The reconstructed model must carry the original field values.
    assert round_tripped.dag_id == orm_record.dag_id
    assert round_tripped.try_number == orm_record.try_number
    assert round_tripped.log_chunk_time == orm_record.log_chunk_time
    assert round_tripped.log_chunk_data == orm_record.log_chunk_data
diff --git a/tests/providers/edge/models/test_edge_worker.py b/tests/providers/edge/models/test_edge_worker.py
new file mode 100644
index 0000000000..9eca293baf
--- /dev/null
+++ b/tests/providers/edge/models/test_edge_worker.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pytest
+
+from airflow.providers.edge.models.edge_worker import (
+    EdgeWorker,
+    EdgeWorkerModel,
+    EdgeWorkerVersionException,
+)
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+pytestmark = pytest.mark.db_test
+
+
class TestEdgeWorker:
    """Tests for the EdgeWorker model helpers."""

    @pytest.fixture(autouse=True)
    def setup_test_cases(self, session: Session):
        # Start every test from an empty worker table.
        session.query(EdgeWorkerModel).delete()

    def test_assert_version(self):
        from airflow import __version__ as airflow_version
        from airflow.providers.edge import __version__ as edge_provider_version

        # Any missing or mismatching version information must be rejected.
        rejected_sysinfo = [
            {},
            {"airflow_version": airflow_version},
            {"edge_provider_version": edge_provider_version},
            {"airflow_version": "1.2.3", "edge_provider_version": edge_provider_version},
            {"airflow_version": airflow_version, "edge_provider_version": "2023.10.07"},
        ]
        for sysinfo in rejected_sysinfo:
            with pytest.raises(EdgeWorkerVersionException):
                EdgeWorker.assert_version(sysinfo)

        # Matching versions on both sides must pass without raising.
        EdgeWorker.assert_version(
            {"airflow_version": airflow_version, "edge_provider_version": edge_provider_version}
        )


Reply via email to