This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 40419dd   Fix running tasks with default_impersonation config (#17229)
40419dd is described below

commit 40419dd371c7be53e6c8017b0c4d1bc7f75d0fb6
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Wed Jul 28 16:00:25 2021 +0100

     Fix running tasks with default_impersonation config (#17229)
    
    When default_impersonation is set in the configuration, Airflow fails
    to run tasks due to a mismatch between the recorded PID and the current PID.
    
    This change fixes it by also checking whether task_runner.run_as_user is set
    and, if so, verifying the PID the same way as when ti.run_as_user is set:
    by comparing the parent PID of the recorded process with the current PID.
---
 airflow/jobs/local_task_job.py    |  5 ++++-
 tests/jobs/test_local_task_job.py | 42 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index eb3aefc..4b29153 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -192,9 +192,12 @@ class LocalTaskJob(BaseJob):
                 )
                 raise AirflowException("Hostname of job runner does not match")
             current_pid = self.task_runner.process.pid
+
             same_process = ti.pid == current_pid
-            if ti.run_as_user:
+
+            if ti.run_as_user or self.task_runner.run_as_user:
                 same_process = psutil.Process(ti.pid).ppid() == current_pid
+
             if ti.pid is not None and not same_process:
                 self.log.warning("Recorded pid %s does not match " "the 
current pid %s", ti.pid, current_pid)
                 raise AirflowException("PID of job runner does not match")
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index ef22505..585d5ef 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -184,6 +184,48 @@ class TestLocalTaskJob:
         with pytest.raises(AirflowException, match='PID of job runner does not 
match'):
             job1.heartbeat_callback()
 
+    @conf_vars({('core', 'default_impersonation'): 'testuser'})
+    @mock.patch('airflow.jobs.local_task_job.psutil')
+    def test_localtaskjob_heartbeat_with_default_impersonation(self, 
psutil_mock, dag_maker):
+        session = settings.Session()
+        with dag_maker('test_localtaskjob_heartbeat'):
+            op1 = DummyOperator(task_id='op1')
+        dr = dag_maker.dag_run
+        ti = dr.get_task_instance(task_id=op1.task_id, session=session)
+        ti.state = State.RUNNING
+        ti.pid = 2
+        ti.hostname = get_hostname()
+        session.commit()
+
+        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, 
executor=SequentialExecutor())
+        ti.task = op1
+        ti.refresh_from_task(op1)
+        job1.task_runner = StandardTaskRunner(job1)
+        job1.task_runner.process = mock.Mock()
+        job1.task_runner.process.pid = 2
+        # Here, ti.pid is 2, the parent process of ti.pid is a mock(different).
+        # And task_runner process is 2. Should fail
+        with pytest.raises(AirflowException, match='PID of job runner does not 
match'):
+            job1.heartbeat_callback()
+
+        job1.task_runner.process.pid = 1
+        # We make the parent process of ti.pid to equal the task_runner 
process id
+        psutil_mock.Process.return_value.ppid.return_value = 1
+        ti.state = State.RUNNING
+        ti.pid = 2
+        # The task_runner process id is 1, same as the parent process of ti.pid
+        # as seen above
+        assert job1.task_runner.run_as_user == 'testuser'
+        session.merge(ti)
+        session.commit()
+        job1.heartbeat_callback(session=None)
+
+        # Here the task_runner process id is changed to 2
+        # while parent process of ti.pid is kept at 1, which is different
+        job1.task_runner.process.pid = 2
+        with pytest.raises(AirflowException, match='PID of job runner does not 
match'):
+            job1.heartbeat_callback()
+
     def test_heartbeat_failed_fast(self):
         """
         Test that task heartbeat will sleep when it fails fast

Reply via email to