This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 788b9c486b Add DB models for Edge Provider (#42047)
788b9c486b is described below
commit 788b9c486bf9e42fb4b10a30edef7f536bb873d6
Author: Jens Scheffler <[email protected]>
AuthorDate: Mon Sep 16 21:55:56 2024 +0200
Add DB models for Edge Provider (#42047)
---
airflow/providers/edge/example_dags/__init__.py | 16 ++
.../edge/example_dags/integration_test.py | 139 ++++++++++++++
airflow/providers/edge/models/__init__.py | 16 ++
airflow/providers/edge/models/edge_job.py | 185 +++++++++++++++++++
airflow/providers/edge/models/edge_logs.py | 153 ++++++++++++++++
airflow/providers/edge/models/edge_worker.py | 203 +++++++++++++++++++++
tests/providers/edge/models/__init__.py | 17 ++
tests/providers/edge/models/test_edge_job.py | 99 ++++++++++
tests/providers/edge/models/test_edge_logs.py | 49 +++++
tests/providers/edge/models/test_edge_worker.py | 65 +++++++
10 files changed, 942 insertions(+)
diff --git a/airflow/providers/edge/example_dags/__init__.py
b/airflow/providers/edge/example_dags/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/edge/example_dags/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/edge/example_dags/integration_test.py
b/airflow/providers/edge/example_dags/integration_test.py
new file mode 100644
index 0000000000..d6074abd30
--- /dev/null
+++ b/airflow/providers/edge/example_dags/integration_test.py
@@ -0,0 +1,139 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+In this DAG all critical functions as integration test are contained.
+
+The DAG should work in all standard setups without error.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+from time import sleep
+
+from airflow.decorators import task, task_group
+from airflow.exceptions import AirflowNotFoundException
+from airflow.hooks.base import BaseHook
+from airflow.models.dag import DAG
+from airflow.models.param import Param
+from airflow.models.variable import Variable
+from airflow.operators.bash import BashOperator
+from airflow.operators.empty import EmptyOperator
+from airflow.operators.python import PythonOperator
+
with DAG(
    dag_id="integration_test",
    dag_display_name="Integration Test",
    # First sentence of the module docstring doubles as the short description.
    description=__doc__.partition(".")[0],
    doc_md=__doc__,
    schedule=None,  # manual trigger only
    start_date=datetime(2024, 7, 1),
    tags=["example", "params", "integration test"],
    params={
        # User-tunable fan-out for the dynamic task mapping below.
        "mapping_count": Param(
            4,
            type="integer",
            title="Mapping Count",
            description="Amount of tasks that should be mapped",
        ),
    },
) as dag:

    @task
    def my_setup():
        """Placeholder setup task (no real work)."""
        print("Assume this is a setup task")

    @task
    def mapping_from_params(**context) -> list[int]:
        """Return [1..mapping_count] from the DAG run params to drive mapping."""
        mapping_count: int = context["params"]["mapping_count"]
        return list(range(1, mapping_count + 1))

    @task
    def add_one(x: int):
        """Mapped task: increment a single value."""
        return x + 1

    @task
    def sum_it(values):
        """Reduce step: sum all mapped results."""
        total = sum(values)
        print(f"Total was {total}")

    @task_group(prefix_group_id=False)
    def mapping_task_group():
        """Exercise dynamic task mapping (expand) plus a reduce."""
        added_values = add_one.expand(x=mapping_from_params())
        sum_it(added_values)

    @task.branch
    def branching():
        # Selects all downstream tasks except "not_executed", so the
        # EmptyOperator is deliberately skipped.
        return ["bash", "virtualenv", "variable", "connection", "classic_bash", "classic_python"]

    @task.bash
    def bash():
        """Exercise the TaskFlow bash decorator."""
        return "echo hello world"

    @task.virtualenv(requirements="numpy")
    def virtualenv():
        """Exercise the virtualenv decorator with an extra requirement."""
        import numpy

        print(f"Welcome to virtualenv with numpy version {numpy.__version__}.")

    @task
    def variable():
        """Exercise Variable set/get/delete round trip."""
        Variable.set("integration_test_key", "value")
        assert Variable.get("integration_test_key") == "value"  # noqa: S101
        Variable.delete("integration_test_key")

    @task
    def connection():
        """Exercise connection lookup; a missing connection is acceptable."""
        try:
            conn = BaseHook.get_connection("integration_test")
            print(f"Got connection {conn}")
        except AirflowNotFoundException:
            print("Connection not found... but also OK.")

    @task_group(prefix_group_id=False)
    def standard_tasks_group():
        """Exercise classic (non-TaskFlow) operators and branching."""
        classic_bash = BashOperator(
            task_id="classic_bash", bash_command="echo Parameter is {{ params.mapping_count }}"
        )

        # Skipped on purpose: branching() does not include "not_executed".
        empty = EmptyOperator(task_id="not_executed")

        def python_call():
            print("Hello world")

        classic_py = PythonOperator(task_id="classic_python", python_callable=python_call)

        branching() >> [bash(), virtualenv(), variable(), connection(), classic_bash, classic_py, empty]

    @task
    def long_running():
        """Simulate a long (15 minute) task with periodic progress output."""
        print("This task runs for 15 minutes")
        for i in range(15):
            sleep(60)
            print(f"Running for {i + 1} minutes now.")
        print("Long running task completed.")

    @task
    def my_teardown():
        """Placeholder teardown task (no real work)."""
        print("Assume this is a teardown task")

    # Setup/teardown bracket the three parallel test groups.
    (
        my_setup().as_setup()
        >> [mapping_task_group(), standard_tasks_group(), long_running()]
        >> my_teardown().as_teardown()
    )
diff --git a/airflow/providers/edge/models/__init__.py
b/airflow/providers/edge/models/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/edge/models/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/edge/models/edge_job.py
b/airflow/providers/edge/models/edge_job.py
new file mode 100644
index 0000000000..b6e316e1f7
--- /dev/null
+++ b/airflow/providers/edge/models/edge_job.py
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from ast import literal_eval
+from datetime import datetime
+from typing import TYPE_CHECKING, List, Optional
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import (
+ Column,
+ Index,
+ Integer,
+ String,
+ select,
+ text,
+)
+
+from airflow.api_internal.internal_api_call import internal_api_call
+from airflow.models.base import Base, StringID
+from airflow.models.taskinstancekey import TaskInstanceKey
+from airflow.serialization.serialized_objects import
add_pydantic_class_type_mapping
+from airflow.utils import timezone
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm.session import Session
+
+
class EdgeJobModel(Base, LoggingMixin):
    """
    A job which is queued, waiting or running on an Edge Worker.

    Each tuple in the database represents and describes the state of one job.
    """

    __tablename__ = "edge_job"
    # The five columns below form the composite primary key and mirror
    # Airflow's TaskInstanceKey (see the ``key`` property).
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    task_id = Column(StringID(), primary_key=True, nullable=False)
    run_id = Column(StringID(), primary_key=True, nullable=False)
    map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
    try_number = Column(Integer, primary_key=True, default=0)
    state = Column(String(20))  # TaskInstanceState value stored as plain string
    queue = Column(String(256))
    command = Column(String(1000))  # serialized command list, see EdgeJob.reserve_task
    queued_dttm = Column(UtcDateTime)
    edge_worker = Column(String(64))  # name of the worker that reserved the job
    last_update = Column(UtcDateTime)

    def __init__(
        self,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int,
        try_number: int,
        state: str,
        queue: str,
        command: str,
        queued_dttm: datetime | None = None,
        edge_worker: str | None = None,
        last_update: datetime | None = None,
    ):
        """Create a job row; ``queued_dttm`` defaults to the current UTC time."""
        self.dag_id = dag_id
        self.task_id = task_id
        self.run_id = run_id
        self.map_index = map_index
        self.try_number = try_number
        self.state = state
        self.queue = queue
        self.command = command
        self.queued_dttm = queued_dttm or timezone.utcnow()
        self.edge_worker = edge_worker
        self.last_update = last_update
        super().__init__()

    # Index supporting reserve_task's "oldest queued job per queue" lookup.
    __table_args__ = (Index("rj_order", state, queued_dttm, queue),)

    @property
    def key(self):
        """Return the TaskInstanceKey identifying this job."""
        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)

    @property
    def last_update_t(self) -> float:
        """Return last_update as a POSIX timestamp."""
        # NOTE(review): raises AttributeError if last_update is still None —
        # assumes callers only use this after an update was recorded; confirm.
        return self.last_update.timestamp()
+
+
class EdgeJob(BaseModel, LoggingMixin):
    """Accessor for edge jobs as logical model."""

    dag_id: str
    task_id: str
    run_id: str
    map_index: int
    try_number: int
    state: TaskInstanceState
    queue: str
    command: List[str]  # noqa: UP006 - prevent Sphinx failing
    queued_dttm: datetime
    edge_worker: Optional[str]  # noqa: UP007 - prevent Sphinx failing
    last_update: Optional[datetime]  # noqa: UP007 - prevent Sphinx failing
    model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

    @property
    def key(self) -> TaskInstanceKey:
        """Return the TaskInstanceKey identifying this job."""
        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)

    @staticmethod
    @internal_api_call
    @provide_session
    def reserve_task(
        worker_name: str, queues: list[str] | None = None, session: Session = NEW_SESSION
    ) -> EdgeJob | None:
        """
        Atomically reserve the oldest queued job for a worker.

        :param worker_name: name of the Edge Worker asking for work.
        :param queues: if given, only consider jobs in these queues.
        :param session: DB session (injected by ``provide_session``).
        :return: the reserved job marked RUNNING, or ``None`` if none is queued.
        """
        query = (
            select(EdgeJobModel)
            .where(EdgeJobModel.state == TaskInstanceState.QUEUED)
            .order_by(EdgeJobModel.queued_dttm)
        )
        if queues:
            query = query.where(EdgeJobModel.queue.in_(queues))
        query = query.limit(1)
        # Row lock with skip_locked so concurrent workers never grab the same job.
        query = with_row_locks(query, of=EdgeJobModel, session=session, skip_locked=True)
        job: EdgeJobModel = session.scalar(query)
        if not job:
            return None
        job.state = TaskInstanceState.RUNNING
        job.edge_worker = worker_name
        job.last_update = timezone.utcnow()
        session.commit()
        return EdgeJob(
            dag_id=job.dag_id,
            task_id=job.task_id,
            run_id=job.run_id,
            map_index=job.map_index,
            try_number=job.try_number,
            state=job.state,
            queue=job.queue,
            # command is persisted as str(list); literal_eval restores the list
            command=literal_eval(job.command),
            queued_dttm=job.queued_dttm,
            edge_worker=job.edge_worker,
            last_update=job.last_update,
        )

    @staticmethod
    @internal_api_call
    @provide_session
    def set_state(task: TaskInstanceKey | tuple, state: TaskInstanceState, session: Session = NEW_SESSION) -> None:
        """
        Update the state of a job identified by its TaskInstanceKey.

        :param task: key of the job (a plain tuple is accepted over the internal API).
        :param state: new state to record; also bumps ``last_update``.
        :param session: DB session (injected by ``provide_session``).
        """
        if isinstance(task, tuple):
            task = TaskInstanceKey(*task)
        query = select(EdgeJobModel).where(
            EdgeJobModel.dag_id == task.dag_id,
            EdgeJobModel.task_id == task.task_id,
            EdgeJobModel.run_id == task.run_id,
            EdgeJobModel.map_index == task.map_index,
            EdgeJobModel.try_number == task.try_number,
        )
        # NOTE(review): session.scalar returns None when no row matches the key,
        # which would raise AttributeError below — confirm callers only pass
        # keys of existing jobs.
        job: EdgeJobModel = session.scalar(query)
        job.state = state
        job.last_update = timezone.utcnow()
        session.commit()

    def __hash__(self):
        # Hash over the identity fields only (same fields as ``key``).
        return f"{self.dag_id}|{self.task_id}|{self.run_id}|{self.map_index}|{self.try_number}".__hash__()


EdgeJob.model_rebuild()

# Register the pydantic model so it can be serialized over the internal API.
add_pydantic_class_type_mapping("edge_job", EdgeJobModel, EdgeJob)
diff --git a/airflow/providers/edge/models/edge_logs.py
b/airflow/providers/edge/models/edge_logs.py
new file mode 100644
index 0000000000..29625f5be7
--- /dev/null
+++ b/airflow/providers/edge/models/edge_logs.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+from functools import lru_cache
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import (
+ Column,
+ Integer,
+ Text,
+ text,
+)
+from sqlalchemy.dialects.mysql import MEDIUMTEXT
+
+from airflow.api_internal.internal_api_call import internal_api_call
+from airflow.configuration import conf
+from airflow.models.base import Base, StringID
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskinstancekey import TaskInstanceKey
+from airflow.serialization.serialized_objects import
add_pydantic_class_type_mapping
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import UtcDateTime
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm.session import Session
+
+
class EdgeLogsModel(Base, LoggingMixin):
    """
    Temporary collected logs from an Edge Worker while job runs on remote site.

    As the Edge Worker in most cases has a local file system and the web UI no access
    to read files from remote site, Edge Workers will send incremental chunks of logs
    of running jobs to the central site. As log storage backends in most cloud cases can not
    append logs, the table is used as buffer to receive. Upon task completion logs can be
    flushed to task log handler.

    Log data therefore is collected in chunks and is only temporary.
    """

    __tablename__ = "edge_logs"
    # The first five columns mirror Airflow's TaskInstanceKey; together with
    # log_chunk_time they form the composite primary key (one row per chunk).
    dag_id = Column(StringID(), primary_key=True, nullable=False)
    task_id = Column(StringID(), primary_key=True, nullable=False)
    run_id = Column(StringID(), primary_key=True, nullable=False)
    map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
    try_number = Column(Integer, primary_key=True, default=0)
    log_chunk_time = Column(UtcDateTime, primary_key=True, nullable=False)
    # MEDIUMTEXT on MySQL, since chunks can exceed the 64KB TEXT limit.
    log_chunk_data = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)

    def __init__(
        self,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int,
        try_number: int,
        log_chunk_time: datetime,
        log_chunk_data: str,
    ):
        """Create one log-chunk row for the given task instance."""
        self.dag_id = dag_id
        self.task_id = task_id
        self.run_id = run_id
        self.map_index = map_index
        self.try_number = try_number
        self.log_chunk_time = log_chunk_time
        self.log_chunk_data = log_chunk_data
        super().__init__()
+
+
class EdgeLogs(BaseModel, LoggingMixin):
    """Accessor for Edge Worker instances as logical model."""

    dag_id: str
    task_id: str
    run_id: str
    map_index: int
    try_number: int
    log_chunk_time: datetime
    log_chunk_data: str
    model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

    @staticmethod
    @internal_api_call
    @provide_session
    def push_logs(
        task: TaskInstanceKey | tuple,
        log_chunk_time: datetime,
        log_chunk_data: str,
        session: Session = NEW_SESSION,
    ) -> None:
        """
        Push an incremental log chunk from Edge Worker to central site.

        :param task: key of the task instance the chunk belongs to
            (a plain tuple is accepted over the internal API).
        :param log_chunk_time: timestamp of the chunk, part of the primary key.
        :param log_chunk_data: raw log text of this chunk.
        :param session: DB session (injected by ``provide_session``).
        """
        if isinstance(task, tuple):
            task = TaskInstanceKey(*task)
        log_chunk = EdgeLogsModel(
            dag_id=task.dag_id,
            task_id=task.task_id,
            run_id=task.run_id,
            map_index=task.map_index,
            try_number=task.try_number,
            log_chunk_time=log_chunk_time,
            log_chunk_data=log_chunk_data,
        )
        session.add(log_chunk)
        # Write logs to local file to make them accessible
        logfile_path = EdgeLogs.logfile_path(task)
        if not logfile_path.exists():
            new_folder_permissions = int(
                conf.get("logging", "file_task_handler_new_folder_permissions", fallback="0o775"), 8
            )
            logfile_path.parent.mkdir(parents=True, exist_ok=True, mode=new_folder_permissions)
        with logfile_path.open("a") as logfile:
            logfile.write(log_chunk_data)

    @staticmethod
    @lru_cache
    def logfile_path(task: TaskInstanceKey) -> Path:
        """Elaborate the path and filename to expect from task execution."""
        # lru_cache avoids re-resolving the TaskInstance per chunk; the cache
        # keys on the full TaskInstanceKey, so each attempt gets its own path.
        from airflow.utils.log.file_task_handler import FileTaskHandler

        ti = TaskInstance.get_task_instance(
            dag_id=task.dag_id,
            run_id=task.run_id,
            task_id=task.task_id,
            map_index=task.map_index,
        )
        if TYPE_CHECKING:
            assert ti
        base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE")
        return Path(base_log_folder, FileTaskHandler(base_log_folder)._render_filename(ti, task.try_number))


EdgeLogs.model_rebuild()

# Register the pydantic model so it can be serialized over the internal API.
add_pydantic_class_type_mapping("edge_logs", EdgeLogsModel, EdgeLogs)
diff --git a/airflow/providers/edge/models/edge_worker.py
b/airflow/providers/edge/models/edge_worker.py
new file mode 100644
index 0000000000..193795e37d
--- /dev/null
+++ b/airflow/providers/edge/models/edge_worker.py
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from datetime import datetime
+from enum import Enum
+from typing import TYPE_CHECKING, List, Optional
+
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import (
+ Column,
+ Integer,
+ String,
+ select,
+)
+
+from airflow.api_internal.internal_api_call import internal_api_call
+from airflow.exceptions import AirflowException
+from airflow.models.base import Base
+from airflow.serialization.serialized_objects import
add_pydantic_class_type_mapping
+from airflow.utils import timezone
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import UtcDateTime
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm.session import Session
+
+
class EdgeWorkerVersionException(AirflowException):
    """Signal a version mismatch between core and Edge Site."""
+
+
class EdgeWorkerState(str, Enum):
    """Status of an Edge Worker instance."""

    STARTING = "starting"
    """Edge Worker is in initialization."""
    RUNNING = "running"
    """Edge Worker is actively running a task."""
    IDLE = "idle"
    """Edge Worker is active and waiting for a task."""
    TERMINATING = "terminating"
    """Edge Worker is completing work and stopping."""
    OFFLINE = "offline"
    """Edge Worker was shut down."""
    UNKNOWN = "unknown"
    """No heartbeat signal from worker for some time, Edge Worker probably down."""
+
+
class EdgeWorkerModel(Base, LoggingMixin):
    """An Edge Worker instance which reports the state and health."""

    __tablename__ = "edge_worker"
    worker_name = Column(String(64), primary_key=True, nullable=False)
    state = Column(String(20))  # EdgeWorkerState value stored as plain string
    queues = Column(String(256))  # comma-separated list of queue names
    first_online = Column(UtcDateTime)
    last_update = Column(UtcDateTime)  # time of last heartbeat
    jobs_active = Column(Integer, default=0)
    jobs_taken = Column(Integer, default=0)
    jobs_success = Column(Integer, default=0)
    jobs_failed = Column(Integer, default=0)
    sysinfo = Column(String(256))  # JSON-encoded dict, see sysinfo_json

    def __init__(
        self,
        worker_name: str,
        state: str,
        queues: list[str] | None,
        first_online: datetime | None = None,
        last_update: datetime | None = None,
    ):
        """Create a worker row; ``first_online`` defaults to the current UTC time."""
        self.worker_name = worker_name
        self.state = state
        # Queues are flattened into a comma-separated string for storage.
        self.queues = ", ".join(queues) if queues else None
        self.first_online = first_online or timezone.utcnow()
        self.last_update = last_update
        super().__init__()

    @property
    def sysinfo_json(self) -> dict | None:
        """Return the decoded sysinfo dict, or None if none was reported."""
        return json.loads(self.sysinfo) if self.sysinfo else None
+
+
class EdgeWorker(BaseModel, LoggingMixin):
    """Accessor for Edge Worker instances as logical model."""

    worker_name: str
    state: EdgeWorkerState
    queues: Optional[List[str]]  # noqa: UP006,UP007 - prevent Sphinx failing
    first_online: datetime
    last_update: Optional[datetime] = None  # noqa: UP007 - prevent Sphinx failing
    jobs_active: int
    jobs_taken: int
    jobs_success: int
    jobs_failed: int
    sysinfo: str
    model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

    @staticmethod
    def assert_version(sysinfo: dict[str, str]) -> None:
        """Check if the Edge Worker version matches the central API site."""
        from airflow import __version__ as airflow_version
        from airflow.providers.edge import __version__ as edge_provider_version

        # Note: In future, more stable versions we might be more liberal, for the
        # moment we require exact version match for Edge Worker and core version
        if "airflow_version" in sysinfo:
            airflow_on_worker = sysinfo["airflow_version"]
            if airflow_on_worker != airflow_version:
                raise EdgeWorkerVersionException(
                    f"Edge Worker runs on Airflow {airflow_on_worker} "
                    f"and the core runs on {airflow_version}. Rejecting access due to difference."
                )
        else:
            raise EdgeWorkerVersionException("Edge Worker does not specify the version it is running on.")

        if "edge_provider_version" in sysinfo:
            provider_on_worker = sysinfo["edge_provider_version"]
            if provider_on_worker != edge_provider_version:
                raise EdgeWorkerVersionException(
                    f"Edge Worker runs on Edge Provider {provider_on_worker} "
                    f"and the core runs on {edge_provider_version}. Rejecting access due to difference."
                )
        else:
            raise EdgeWorkerVersionException(
                "Edge Worker does not specify the provider version it is running on."
            )

    @staticmethod
    @internal_api_call
    @provide_session
    def register_worker(
        worker_name: str,
        state: EdgeWorkerState,
        queues: list[str] | None,
        sysinfo: dict[str, str],
        session: Session = NEW_SESSION,
    ) -> EdgeWorker:
        """
        Register (or re-register) a worker and record its current state.

        :param worker_name: unique name of the worker instance.
        :param state: state reported by the worker.
        :param queues: queues the worker serves, or None for all.
        :param sysinfo: version/system info; versions must match the core site.
        :param session: DB session (injected by ``provide_session``).
        :raises EdgeWorkerVersionException: on version mismatch with core.
        """
        EdgeWorker.assert_version(sysinfo)
        query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
        worker: EdgeWorkerModel = session.scalar(query)
        if not worker:
            worker = EdgeWorkerModel(worker_name=worker_name, state=state, queues=queues)
        worker.state = state
        # NOTE(review): assigns the raw list here, while EdgeWorkerModel.__init__
        # stores ", ".join(queues) into the String column — inconsistent; confirm
        # which representation the column is meant to hold.
        worker.queues = queues
        worker.sysinfo = json.dumps(sysinfo)
        worker.last_update = timezone.utcnow()
        session.add(worker)
        return EdgeWorker(
            worker_name=worker_name,
            state=state,
            queues=worker.queues,
            first_online=worker.first_online,
            last_update=worker.last_update,
            # Counters default to 0 for a freshly created (not yet flushed) row.
            jobs_active=worker.jobs_active or 0,
            jobs_taken=worker.jobs_taken or 0,
            jobs_success=worker.jobs_success or 0,
            jobs_failed=worker.jobs_failed or 0,
            sysinfo=worker.sysinfo or "{}",
        )

    @staticmethod
    @internal_api_call
    @provide_session
    def set_state(
        worker_name: str,
        state: EdgeWorkerState,
        jobs_active: int,
        sysinfo: dict[str, str],
        session: Session = NEW_SESSION,
    ) -> None:
        """
        Heartbeat update: record state, active job count and sysinfo of a worker.

        :raises EdgeWorkerVersionException: on version mismatch with core.
        """
        EdgeWorker.assert_version(sysinfo)
        query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
        # NOTE(review): assumes the worker was registered before; an unknown
        # worker_name yields None and would raise AttributeError — confirm.
        worker: EdgeWorkerModel = session.scalar(query)
        worker.state = state
        worker.jobs_active = jobs_active
        worker.sysinfo = json.dumps(sysinfo)
        worker.last_update = timezone.utcnow()
        session.commit()


EdgeWorker.model_rebuild()

# Register the pydantic model so it can be serialized over the internal API.
add_pydantic_class_type_mapping("edge_worker", EdgeWorkerModel, EdgeWorker)
diff --git a/tests/providers/edge/models/__init__.py
b/tests/providers/edge/models/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/tests/providers/edge/models/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/edge/models/test_edge_job.py
b/tests/providers/edge/models/test_edge_job.py
new file mode 100644
index 0000000000..91dd897deb
--- /dev/null
+++ b/tests/providers/edge/models/test_edge_job.py
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pytest
+
+from airflow.providers.edge.models.edge_job import EdgeJob, EdgeJobModel
+from airflow.utils import timezone
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
+pytestmark = pytest.mark.db_test
+pytest.importorskip("pydantic", minversion="2.0.0")
+
+
class TestEdgeJob:
    @pytest.fixture(autouse=True)
    def setup_test_cases(self, session: Session):
        # Start every test from an empty edge_job table.
        session.query(EdgeJobModel).delete()

    def test_reserve_task_no_job(self):
        """With an empty table there is nothing to reserve."""
        job = EdgeJob.reserve_task("worker")
        assert job is None

    def test_reserve_task_has_one(self, session: Session):
        """A queued job is returned to the worker and marked RUNNING in the DB."""
        rjm = EdgeJobModel(
            dag_id="test_dag",
            task_id="test_task",
            run_id="test_run",
            map_index=-1,
            try_number=1,
            state=TaskInstanceState.QUEUED,
            queue="default",
            # command is stored as str(list), matching EdgeJob.reserve_task's literal_eval
            command=str(["hello", "world"]),
            queued_dttm=timezone.utcnow(),
        )
        session.add(rjm)
        session.commit()

        job = EdgeJob.reserve_task("worker")
        assert job
        assert job.edge_worker == "worker"
        assert job.queue == "default"
        assert job.dag_id == "test_dag"
        assert job.task_id == "test_task"
        assert job.run_id == "test_run"

        # The DB row must reflect the reservation as well.
        jobs: list[EdgeJobModel] = session.query(EdgeJobModel).all()
        assert len(jobs) == 1
        assert jobs[0].state == TaskInstanceState.RUNNING
        assert jobs[0].edge_worker == "worker"
        assert jobs[0].queue == "default"
        assert jobs[0].dag_id == "test_dag"
        assert jobs[0].task_id == "test_task"
        assert jobs[0].run_id == "test_run"

    def test_set_state(self, session: Session):
        """set_state updates the job row's state and last_update timestamp."""
        rjm = EdgeJobModel(
            dag_id="test_dag",
            task_id="test_task",
            run_id="test_run",
            map_index=-1,
            try_number=1,
            state=TaskInstanceState.RUNNING,
            queue="default",
            command=str(["hello", "world"]),
            queued_dttm=timezone.utcnow(),
        )
        session.add(rjm)
        session.commit()

        EdgeJob.set_state(rjm.key, TaskInstanceState.FAILED)

        jobs: list[EdgeJobModel] = session.query(EdgeJobModel).all()
        assert len(jobs) == 1
        assert jobs[0].state == TaskInstanceState.FAILED
        assert jobs[0].last_update
        assert jobs[0].queue == "default"
        assert jobs[0].dag_id == "test_dag"
        assert jobs[0].task_id == "test_task"
        assert jobs[0].run_id == "test_run"
diff --git a/tests/providers/edge/models/test_edge_logs.py
b/tests/providers/edge/models/test_edge_logs.py
new file mode 100644
index 0000000000..0bb3307e0c
--- /dev/null
+++ b/tests/providers/edge/models/test_edge_logs.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.edge.models.edge_logs import EdgeLogs, EdgeLogsModel
+from airflow.utils import timezone
+
+pytestmark = pytest.mark.db_test
+
+pytest.importorskip("pydantic", minversion="2.0.0")
+
+
def test_serializing_pydantic_edge_logs():
    """An EdgeLogsModel row must survive a pydantic JSON round-trip intact."""
    log_record = EdgeLogsModel(
        dag_id="test_dag",
        task_id="test_task",
        run_id="test_run",
        map_index=-1,
        try_number=1,
        log_chunk_time=timezone.utcnow(),
        log_chunk_data="some logs captured",
    )

    # ORM row -> pydantic model -> JSON string.
    json_string = EdgeLogs.model_validate(log_record).model_dump_json()
    print(json_string)

    # JSON string -> pydantic model; selected fields must match the original row.
    restored = EdgeLogs.model_validate_json(json_string)
    for attr in ("dag_id", "try_number", "log_chunk_time", "log_chunk_data"):
        assert getattr(restored, attr) == getattr(log_record, attr)
diff --git a/tests/providers/edge/models/test_edge_worker.py
b/tests/providers/edge/models/test_edge_worker.py
new file mode 100644
index 0000000000..9eca293baf
--- /dev/null
+++ b/tests/providers/edge/models/test_edge_worker.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pytest
+
+from airflow.providers.edge.models.edge_worker import (
+ EdgeWorker,
+ EdgeWorkerModel,
+ EdgeWorkerVersionException,
+)
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
+pytestmark = pytest.mark.db_test
+
+
class TestEdgeWorker:
    @pytest.fixture(autouse=True)
    def setup_test_cases(self, session: Session):
        # Start every test from an empty edge_worker table.
        session.query(EdgeWorkerModel).delete()

    def test_assert_version(self):
        """assert_version rejects missing or mismatching versions, accepts exact match."""
        from airflow import __version__ as airflow_version
        from airflow.providers.edge import __version__ as edge_provider_version

        # No version info at all.
        with pytest.raises(EdgeWorkerVersionException):
            EdgeWorker.assert_version({})

        # Provider version missing.
        with pytest.raises(EdgeWorkerVersionException):
            EdgeWorker.assert_version({"airflow_version": airflow_version})

        # Airflow version missing.
        with pytest.raises(EdgeWorkerVersionException):
            EdgeWorker.assert_version({"edge_provider_version": edge_provider_version})

        # Airflow version mismatch.
        with pytest.raises(EdgeWorkerVersionException):
            EdgeWorker.assert_version(
                {"airflow_version": "1.2.3", "edge_provider_version": edge_provider_version}
            )

        # Provider version mismatch.
        with pytest.raises(EdgeWorkerVersionException):
            EdgeWorker.assert_version(
                {"airflow_version": airflow_version, "edge_provider_version": "2023.10.07"}
            )

        # Exact match of both versions must pass without raising.
        EdgeWorker.assert_version(
            {"airflow_version": airflow_version, "edge_provider_version": edge_provider_version}
        )