This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 55e419e95ab Remove AIP-44 from Job (#44493)
55e419e95ab is described below
commit 55e419e95ab027d161cef95571300af9b2c81a0d
Author: Jarek Potiuk <[email protected]>
AuthorDate: Sat Nov 30 03:19:32 2024 +0100
Remove AIP-44 from Job (#44493)
Part of #44436
---
airflow/jobs/job.py | 128 +++++----------------
.../providers/edge/worker_api/routes/rpc_api.py | 1 -
tests/jobs/test_base_job.py | 9 +-
3 files changed, 32 insertions(+), 106 deletions(-)
diff --git a/airflow/jobs/job.py b/airflow/jobs/job.py
index 6e802372d83..75a075efdfc 100644
--- a/airflow/jobs/job.py
+++ b/airflow/jobs/job.py
@@ -26,13 +26,11 @@ from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import backref, foreign, relationship
from sqlalchemy.orm.session import make_transient
-from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.executors.executor_loader import ExecutorLoader
from airflow.listeners.listener import get_listener_manager
from airflow.models.base import ID_LEN, Base
-from airflow.serialization.pydantic.job import JobPydantic
from airflow.stats import Stats
from airflow.traces.tracer import Trace, add_span
from airflow.utils import timezone
@@ -40,8 +38,7 @@ from airflow.utils.helpers import convert_camel_to_snake
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser
-from airflow.utils.retries import retry_db_transaction
-from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import UtcDateTime
from airflow.utils.state import JobState
@@ -168,7 +165,10 @@ class Job(Base, LoggingMixin):
except Exception as e:
self.log.error("on_kill() method failed: %s", e)
- Job._kill(job_id=self.id, session=session)
+ job = session.scalar(select(Job).where(Job.id == self.id,
session=session).limit(1))
+ job.end_date = timezone.utcnow()
+ session.merge(job)
+ session.commit()
raise AirflowException("Job shut down externally.")
def on_kill(self):
@@ -201,7 +201,7 @@ class Job(Base, LoggingMixin):
try:
span.set_attribute("heartbeat", str(self.latest_heartbeat))
# This will cause it to load from the db
- self._merge_from(Job._fetch_from_db(self, session))
+ session.merge(self)
previous_heartbeat = self.latest_heartbeat
if self.state == JobState.RESTARTING:
@@ -217,17 +217,19 @@ class Job(Base, LoggingMixin):
if span.is_recording():
span.add_event(name="sleep", attributes={"sleep_for":
sleep_for})
sleep(sleep_for)
-
- job = Job._update_heartbeat(job=self, session=session)
- self._merge_from(job)
- time_since_last_heartbeat = (timezone.utcnow() -
previous_heartbeat).total_seconds()
- health_check_threshold_value =
health_check_threshold(self.job_type, self.heartrate)
- if time_since_last_heartbeat > health_check_threshold_value:
- self.log.info("Heartbeat recovered after %.2f seconds",
time_since_last_heartbeat)
- # At this point, the DB has updated.
- previous_heartbeat = self.latest_heartbeat
-
- heartbeat_callback(session)
+ # Update last heartbeat time
+ with create_session() as session:
+ # Make the session aware of this object
+ session.merge(self)
+ self.latest_heartbeat = timezone.utcnow()
+ session.commit()
+ time_since_last_heartbeat = (timezone.utcnow() -
previous_heartbeat).total_seconds()
+ health_check_threshold_value =
health_check_threshold(self.job_type, self.heartrate)
+ if time_since_last_heartbeat >
health_check_threshold_value:
+ self.log.info("Heartbeat recovered after %.2f
seconds", time_since_last_heartbeat)
+ # At this point, the DB has updated.
+ previous_heartbeat = self.latest_heartbeat
+ heartbeat_callback(session)
self.log.debug("[heartbeat]")
self.heartbeat_failed = False
except OperationalError:
@@ -260,36 +262,23 @@ class Job(Base, LoggingMixin):
Stats.incr(self.__class__.__name__.lower() + "_start", 1, 1)
self.state = JobState.RUNNING
self.start_date = timezone.utcnow()
- self._merge_from(Job._add_to_db(job=self, session=session))
+ session.add(self)
+ session.commit()
make_transient(self)
@provide_session
def complete_execution(self, session: Session = NEW_SESSION):
get_listener_manager().hook.before_stopping(component=self)
self.end_date = timezone.utcnow()
- Job._update_in_db(job=self, session=session)
+ session.merge(self)
+ session.commit()
Stats.incr(self.__class__.__name__.lower() + "_end", 1, 1)
@provide_session
- def most_recent_job(self, session: Session = NEW_SESSION) -> Job |
JobPydantic | None:
+ def most_recent_job(self, session: Session = NEW_SESSION) -> Job | None:
"""Return the most recent job of this type, if any, based on last
heartbeat received."""
return most_recent_job(self.job_type, session=session)
- def _merge_from(self, job: Job | JobPydantic | None):
- if job is None:
- self.log.error("Job is empty: %s", self.id)
- return
- self.id = job.id
- self.dag_id = job.dag_id
- self.state = job.state
- self.job_type = job.job_type
- self.start_date = job.start_date
- self.end_date = job.end_date
- self.latest_heartbeat = job.latest_heartbeat
- self.executor_class = job.executor_class
- self.hostname = job.hostname
- self.unixname = job.unixname
-
@staticmethod
def _heartrate(job_type: str) -> float:
if job_type == "TriggererJob":
@@ -312,74 +301,9 @@ class Job(Base, LoggingMixin):
and (timezone.utcnow() - latest_heartbeat).total_seconds() <
health_check_threshold_value
)
- @staticmethod
- @internal_api_call
- @provide_session
- def _kill(job_id: str, session: Session = NEW_SESSION) -> Job |
JobPydantic:
- job = session.scalar(select(Job).where(Job.id == job_id).limit(1))
- job.end_date = timezone.utcnow()
- session.merge(job)
- session.commit()
- return job
-
- @staticmethod
- @internal_api_call
- @provide_session
- @retry_db_transaction
- def _fetch_from_db(job: Job | JobPydantic, session: Session = NEW_SESSION)
-> Job | JobPydantic | None:
- if isinstance(job, Job):
- # not Internal API
- session.merge(job)
- return job
- # Internal API,
- return session.scalar(select(Job).where(Job.id == job.id).limit(1))
-
- @staticmethod
- @internal_api_call
- @provide_session
- def _add_to_db(job: Job | JobPydantic, session: Session = NEW_SESSION) ->
Job | JobPydantic:
- if isinstance(job, JobPydantic):
- orm_job = Job()
- orm_job._merge_from(job)
- else:
- orm_job = job
- session.add(orm_job)
- session.commit()
- return orm_job
-
- @staticmethod
- @internal_api_call
- @provide_session
- def _update_in_db(job: Job | JobPydantic, session: Session = NEW_SESSION):
- if isinstance(job, Job):
- # not Internal API
- session.merge(job)
- session.commit()
- # Internal API.
- orm_job: Job | None = session.scalar(select(Job).where(Job.id ==
job.id).limit(1))
- if orm_job is None:
- return
- orm_job._merge_from(job)
- session.merge(orm_job)
- session.commit()
-
- @staticmethod
- @internal_api_call
- @provide_session
- @retry_db_transaction
- def _update_heartbeat(job: Job | JobPydantic, session: Session =
NEW_SESSION) -> Job | JobPydantic:
- orm_job: Job | None = session.scalar(select(Job).where(Job.id ==
job.id).limit(1))
- if orm_job is None:
- return job
- orm_job.latest_heartbeat = timezone.utcnow()
- session.merge(orm_job)
- session.commit()
- return orm_job
-
-@internal_api_call
@provide_session
-def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job |
JobPydantic | None:
+def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job |
None:
"""
Return the most recent job of this type, if any, based on last heartbeat
received.
@@ -434,7 +358,7 @@ def execute_job(job: Job, execute_callable: Callable[[],
int | None]) -> int | N
which happens in the "complete_execution" step (which again can be
executed locally in case of
database operations or over the Internal API call.
- :param job: Job to execute - it can be either DB job or it's Pydantic
serialized version. It does
+ :param job: Job to execute - DB job. It does
not really matter, because except of running the heartbeat and state
setting,
the runner should not modify the job state.
diff --git a/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py
b/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py
index b3ceaa68700..aa5b30f5ab7 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py
@@ -119,7 +119,6 @@ def _initialize_method_map() -> dict[str, Callable]:
expand_alias_to_assets,
FileTaskHandler._render_filename_db_access,
Job._add_to_db,
- Job._fetch_from_db,
Job._kill,
Job._update_heartbeat,
Job._update_in_db,
diff --git a/tests/jobs/test_base_job.py b/tests/jobs/test_base_job.py
index 4d5f3787ba6..63f5ef1910a 100644
--- a/tests/jobs/test_base_job.py
+++ b/tests/jobs/test_base_job.py
@@ -228,10 +228,13 @@ class TestJob:
job.latest_heartbeat = timezone.utcnow() -
datetime.timedelta(seconds=10)
assert job.is_alive() is False, "Completed jobs even with recent
heartbeat should not be alive"
- def test_heartbeat_failed(self, caplog):
+ @patch("airflow.jobs.job.create_session")
+ def test_heartbeat_failed(self, mock_create_session, caplog):
when = timezone.utcnow() - datetime.timedelta(seconds=60)
- mock_session = Mock(name="MockSession")
- mock_session.commit.side_effect = OperationalError("Force fail", {},
None)
+ with create_session() as session:
+ mock_session = Mock(spec_set=session, name="MockSession")
+ mock_create_session.return_value.__enter__.return_value =
mock_session
+ mock_session.commit.side_effect = OperationalError("Force fail",
{}, None)
job = Job(heartrate=10, state=State.RUNNING)
job.latest_heartbeat = when
with caplog.at_level(logging.ERROR):