You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/11 16:13:33 UTC

[airflow] branch main updated: Fix pid check (#24636)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 26c9768e44 Fix pid check (#24636)
26c9768e44 is described below

commit 26c9768e44315e08c298776e0d7fb400b442bf96
Author: Kevin Crouse <ke...@rutgers.edu>
AuthorDate: Mon Jul 11 12:13:16 2022 -0400

    Fix pid check (#24636)
---
 airflow/jobs/local_task_job.py    |  8 +++++++-
 tests/jobs/test_local_task_job.py | 30 ++++++++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index 5711342e04..b73c8992d8 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -195,7 +195,13 @@ class LocalTaskJob(BaseJob):
             recorded_pid = ti.pid
             same_process = recorded_pid == current_pid
 
-            if ti.run_as_user or self.task_runner.run_as_user:
+            if recorded_pid is not None and (ti.run_as_user or self.task_runner.run_as_user):
+                # when running as another user, compare the task runner pid to the parent of
+                # the recorded pid because user delegation becomes an extra process level.
+                # However, if recorded_pid is None, pass that through as it signals the task
+                # runner process has already completed and been cleared out. `psutil.Process`
+                # uses the current process if the parameter is None, which is not what is intended
+                # for comparison.
                 recorded_pid = psutil.Process(ti.pid).ppid()
                 same_process = recorded_pid == current_pid
 
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index 362bba1828..74348f8f43 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -154,6 +154,16 @@ class TestLocalTaskJob:
         with pytest.raises(AirflowException):
             job1.heartbeat_callback()
 
+        # Now, set the ti.pid to None and test that no error
+        # is raised.
+        ti.pid = None
+        session.merge(ti)
+        session.commit()
+        assert ti.pid != job1.task_runner.process.pid
+        assert not ti.run_as_user
+        assert not job1.task_runner.run_as_user
+        job1.heartbeat_callback()
+
     @mock.patch('subprocess.check_call')
     @mock.patch('airflow.jobs.local_task_job.psutil')
     def test_localtaskjob_heartbeat_with_run_as_user(self, psutil_mock, _, dag_maker):
@@ -196,6 +206,16 @@ class TestLocalTaskJob:
         with pytest.raises(AirflowException, match='PID of job runner does not match'):
             job1.heartbeat_callback()
 
+        # Here we set the ti.pid to None and test that no error is
+        # raised
+        ti.pid = None
+        session.merge(ti)
+        session.commit()
+        assert ti.run_as_user
+        assert job1.task_runner.run_as_user == ti.run_as_user
+        assert ti.pid != job1.task_runner.process.pid
+        job1.heartbeat_callback()
+
     @conf_vars({('core', 'default_impersonation'): 'testuser'})
     @mock.patch('subprocess.check_call')
     @mock.patch('airflow.jobs.local_task_job.psutil')
@@ -239,6 +259,16 @@ class TestLocalTaskJob:
         with pytest.raises(AirflowException, match='PID of job runner does not match'):
             job1.heartbeat_callback()
 
+        # Now, set the ti.pid to None and test that no error
+        # is raised.
+        ti.pid = None
+        session.merge(ti)
+        session.commit()
+        assert job1.task_runner.run_as_user == 'testuser'
+        assert ti.run_as_user is None
+        assert ti.pid != job1.task_runner.process.pid
+        job1.heartbeat_callback()
+
     def test_heartbeat_failed_fast(self):
         """
         Test that task heartbeat will sleep when it fails fast