You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2021/04/07 22:20:53 UTC

[airflow] branch master updated: Display explicit error in case UID has no actual username (#15212)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 3e9e954  Display explicit error in case UID has no actual username (#15212)
3e9e954 is described below

commit 3e9e954d9ec5236cbbc6da2091b38e69c1b4c0c0
Author: Andrew Godwin <an...@astronomer.io>
AuthorDate: Wed Apr 7 16:20:38 2021 -0600

    Display explicit error in case UID has no actual username (#15212)
    
    Fixes #9963 : Don't require a current username
    
    Previously, we used getpass.getuser() with no error handling, which errors
    out with an unhelpful KeyError if there is no username specified for the
    current UID (which happens a lot more in environments like Docker &
    Kubernetes). This updates most calls to use our own wrapper, which raises
    an explicit, descriptive error if there is no username.
---
 airflow/cli/commands/info_command.py               |  4 ++--
 airflow/jobs/base_job.py                           |  4 ++--
 airflow/models/taskinstance.py                     |  4 ++--
 airflow/providers/microsoft/winrm/hooks/winrm.py   |  8 ++++++--
 airflow/providers/ssh/hooks/ssh.py                 |  8 ++++++--
 airflow/task/task_runner/base_task_runner.py       |  4 ++--
 airflow/task/task_runner/cgroup_task_runner.py     |  4 ++--
 airflow/utils/cli.py                               |  5 ++---
 airflow/utils/platform.py                          | 23 ++++++++++++++++++++++
 .../endpoints/test_task_instance_endpoint.py       |  8 ++++----
 .../schemas/test_task_instance_schema.py           |  6 +++---
 tests/jobs/test_base_job.py                        |  2 +-
 .../providers/microsoft/winrm/hooks/test_winrm.py  |  2 +-
 .../task/task_runner/test_standard_task_runner.py  |  6 +++---
 14 files changed, 59 insertions(+), 29 deletions(-)

diff --git a/airflow/cli/commands/info_command.py b/airflow/cli/commands/info_command.py
index a0a65b6..2842722 100644
--- a/airflow/cli/commands/info_command.py
+++ b/airflow/cli/commands/info_command.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Config sub-commands"""
-import getpass
 import locale
 import logging
 import os
@@ -33,6 +32,7 @@ from airflow.cli.simple_table import AirflowConsole
 from airflow.providers_manager import ProvidersManager
 from airflow.typing_compat import Protocol
 from airflow.utils.cli import suppress_logs_and_warning
+from airflow.utils.platform import getuser
 from airflow.version import version as airflow_version
 
 log = logging.getLogger(__name__)
@@ -67,7 +67,7 @@ class PiiAnonymizer(Anonymizer):
 
     def __init__(self):
         home_path = os.path.expanduser("~")
-        username = getpass.getuser()
+        username = getuser()
         self._path_replacements = {home_path: "${HOME}", username: "${USER}"}
 
     def process_path(self, value):
diff --git a/airflow/jobs/base_job.py b/airflow/jobs/base_job.py
index f785189..1837c27 100644
--- a/airflow/jobs/base_job.py
+++ b/airflow/jobs/base_job.py
@@ -17,7 +17,6 @@
 # under the License.
 #
 
-import getpass
 from time import sleep
 from typing import Optional
 
@@ -37,6 +36,7 @@ from airflow.utils import timezone
 from airflow.utils.helpers import convert_camel_to_snake
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
+from airflow.utils.platform import getuser
 from airflow.utils.session import create_session, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
 from airflow.utils.state import State
@@ -100,7 +100,7 @@ class BaseJob(Base, LoggingMixin):
         self.latest_heartbeat = timezone.utcnow()
         if heartrate is not None:
             self.heartrate = heartrate
-        self.unixname = getpass.getuser()
+        self.unixname = getuser()
         self.max_tis_per_query = conf.getint('scheduler', 'max_tis_per_query')
         super().__init__(*args, **kwargs)
 
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 4b49177..3b1c015 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -16,7 +16,6 @@
 # specific language governing permissions and limitations
 # under the License.
 import contextlib
-import getpass
 import hashlib
 import logging
 import math
@@ -68,6 +67,7 @@ from airflow.utils.helpers import is_container
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
 from airflow.utils.operator_helpers import context_to_airflow_vars
+from airflow.utils.platform import getuser
 from airflow.utils.session import provide_session
 from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
 from airflow.utils.state import State
@@ -327,7 +327,7 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
         self.execution_date = execution_date
 
         self.try_number = 0
-        self.unixname = getpass.getuser()
+        self.unixname = getuser()
         if state:
             self.state = state
         self.hostname = ''
diff --git a/airflow/providers/microsoft/winrm/hooks/winrm.py b/airflow/providers/microsoft/winrm/hooks/winrm.py
index 2524a81..64f15a5 100644
--- a/airflow/providers/microsoft/winrm/hooks/winrm.py
+++ b/airflow/providers/microsoft/winrm/hooks/winrm.py
@@ -17,7 +17,6 @@
 # under the License.
 #
 """Hook for winrm remote execution."""
-import getpass
 from typing import Optional
 
 from winrm.protocol import Protocol
@@ -25,6 +24,11 @@ from winrm.protocol import Protocol
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 
+try:
+    from airflow.utils.platform import getuser
+except ImportError:
+    from getpass import getuser
+
 
 # TODO: Fixme please - I have too complex implementation
 # pylint: disable=too-many-instance-attributes,too-many-arguments,too-many-branches
@@ -201,7 +205,7 @@ class WinRMHook(BaseHook):
                 self.remote_host,
                 self.ssh_conn_id,
             )
-            self.username = getpass.getuser()
+            self.username = getuser()
 
         # If endpoint is not set, then build a standard wsman endpoint from host and port.
         if not self.endpoint:
diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py
index a0d0a3a..d94b2a3 100644
--- a/airflow/providers/ssh/hooks/ssh.py
+++ b/airflow/providers/ssh/hooks/ssh.py
@@ -16,7 +16,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Hook for SSH connections."""
-import getpass
 import os
 import warnings
 from base64 import decodebytes
@@ -30,6 +29,11 @@ from sshtunnel import SSHTunnelForwarder
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 
+try:
+    from airflow.utils.platform import getuser
+except ImportError:
+    from getpass import getuser
+
 
 class SSHHook(BaseHook):  # pylint: disable=too-many-instance-attributes
     """
@@ -173,7 +177,7 @@ class SSHHook(BaseHook):  # pylint: disable=too-many-instance-attributes
                 self.remote_host,
                 self.ssh_conn_id,
             )
-            self.username = getpass.getuser()
+            self.username = getuser()
 
         user_ssh_config_filename = os.path.expanduser('~/.ssh/config')
         if os.path.isfile(user_ssh_config_filename):
diff --git a/airflow/task/task_runner/base_task_runner.py b/airflow/task/task_runner/base_task_runner.py
index 81235ea..07e8625 100644
--- a/airflow/task/task_runner/base_task_runner.py
+++ b/airflow/task/task_runner/base_task_runner.py
@@ -16,7 +16,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """Base task runner"""
-import getpass
 import os
 import subprocess
 import threading
@@ -29,6 +28,7 @@ from airflow.models.taskinstance import load_error_file
 from airflow.utils.configuration import tmp_configuration_copy
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
+from airflow.utils.platform import getuser
 
 PYTHONPATH_VAR = 'PYTHONPATH'
 
@@ -60,7 +60,7 @@ class BaseTaskRunner(LoggingMixin):
         # Add sudo commands to change user if we need to. Needed to handle SubDagOperator
         # case using a SequentialExecutor.
         self.log.debug("Planning to run as the %s user", self.run_as_user)
-        if self.run_as_user and (self.run_as_user != getpass.getuser()):
+        if self.run_as_user and (self.run_as_user != getuser()):
             # We want to include any environment variables now, as we won't
             # want to have to specify them in the sudo call - they would show
             # up in `ps` that way! And run commands now, as the other user
diff --git a/airflow/task/task_runner/cgroup_task_runner.py b/airflow/task/task_runner/cgroup_task_runner.py
index 01360dd..04b3a23 100644
--- a/airflow/task/task_runner/cgroup_task_runner.py
+++ b/airflow/task/task_runner/cgroup_task_runner.py
@@ -19,7 +19,6 @@
 """Task runner for cgroup to run Airflow task"""
 
 import datetime
-import getpass
 import os
 import uuid
 
@@ -28,6 +27,7 @@ from cgroupspy import trees
 
 from airflow.task.task_runner.base_task_runner import BaseTaskRunner
 from airflow.utils.operator_resources import Resources
+from airflow.utils.platform import getuser
 from airflow.utils.process_utils import reap_process_group
 
 
@@ -70,7 +70,7 @@ class CgroupTaskRunner(BaseTaskRunner):
         self.cpu_cgroup_name = None
         self._created_cpu_cgroup = False
         self._created_mem_cgroup = False
-        self._cur_user = getpass.getuser()
+        self._cur_user = getuser()
 
     def _create_cgroup(self, path):
         """
diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py
index b64b8a8..80973b7 100644
--- a/airflow/utils/cli.py
+++ b/airflow/utils/cli.py
@@ -18,7 +18,6 @@
 #
 """Utilities module for cli"""
 import functools
-import getpass
 import json
 import logging
 import os
@@ -35,7 +34,7 @@ from typing import TYPE_CHECKING, Callable, Optional, TypeVar, cast
 from airflow import settings
 from airflow.exceptions import AirflowException
 from airflow.utils import cli_action_loggers
-from airflow.utils.platform import is_terminal_support_colors
+from airflow.utils.platform import getuser, is_terminal_support_colors
 from airflow.utils.session import provide_session
 
 T = TypeVar("T", bound=Callable)  # pylint: disable=invalid-name
@@ -131,7 +130,7 @@ def _build_metrics(func_name, namespace):
         'sub_command': func_name,
         'start_datetime': datetime.utcnow(),
         'full_command': f'{full_command}',
-        'user': getpass.getuser(),
+        'user': getuser(),
     }
 
     if not isinstance(namespace, Namespace):
diff --git a/airflow/utils/platform.py b/airflow/utils/platform.py
index 56087b9..73eb609 100644
--- a/airflow/utils/platform.py
+++ b/airflow/utils/platform.py
@@ -16,6 +16,7 @@
 # under the License.
 
 """Platform and system specific function."""
+import getpass
 import logging
 import os
 import pkgutil
@@ -57,3 +58,25 @@ def get_airflow_git_version():
         log.debug(e)
 
     return git_version
+
+
+def getuser() -> str:
+    """
+    Return the username of the current user, raising a descriptive error
+    if the current UID has no username.
+
+    We don't want to fall back to os.getuid() because not having a username
+    probably means the rest of the user environment is wrong (e.g. no $HOME).
+    Explicit failure is better than silently trying to work badly.
+    """
+    try:
+        return getpass.getuser()
+    except (KeyError, OSError) as err:
+        # KeyError before Python 3.13, OSError from 3.13; inner import avoids a cycle
+        from airflow.exceptions import AirflowConfigException
+
+        raise AirflowConfigException(
+            "The user that Airflow is running as has no username; you must run "
+            "Airflow as a full user, with a username and home directory, "
+            "in order for it to function properly."
+        ) from err
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 307befd..6c0ec55 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 import datetime as dt
-import getpass
 from unittest import mock
 
 import pytest
@@ -23,6 +22,7 @@ from parameterized import parameterized
 
 from airflow.models import DagBag, DagRun, SlaMiss, TaskInstance
 from airflow.security import permissions
+from airflow.utils.platform import getuser
 from airflow.utils.session import provide_session
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
@@ -160,7 +160,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
             "state": "running",
             "task_id": "print_the_context",
             "try_number": 0,
-            "unixname": getpass.getuser(),
+            "unixname": getuser(),
         }
 
     def test_should_respond_200_with_task_state_in_removed(self, session):
@@ -190,7 +190,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
             "state": "removed",
             "task_id": "print_the_context",
             "try_number": 0,
-            "unixname": getpass.getuser(),
+            "unixname": getuser(),
         }
 
     def test_should_respond_200_task_instance_with_sla(self, session):
@@ -238,7 +238,7 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
             "state": "running",
             "task_id": "print_the_context",
             "try_number": 0,
-            "unixname": getpass.getuser(),
+            "unixname": getuser(),
         }
 
     def test_should_raises_401_unauthenticated(self):
diff --git a/tests/api_connexion/schemas/test_task_instance_schema.py b/tests/api_connexion/schemas/test_task_instance_schema.py
index c28e47d..73895ae 100644
--- a/tests/api_connexion/schemas/test_task_instance_schema.py
+++ b/tests/api_connexion/schemas/test_task_instance_schema.py
@@ -16,7 +16,6 @@
 # under the License.
 
 import datetime as dt
-import getpass
 import unittest
 
 import pytest
@@ -30,6 +29,7 @@ from airflow.api_connexion.schemas.task_instance_schema import (
 )
 from airflow.models import DAG, SlaMiss, TaskInstance as TI
 from airflow.operators.dummy import DummyOperator
+from airflow.utils.platform import getuser
 from airflow.utils.session import create_session, provide_session
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
@@ -87,7 +87,7 @@ class TestTaskInstanceSchema(unittest.TestCase):
             "state": "running",
             "task_id": "TEST_TASK_ID",
             "try_number": 0,
-            "unixname": getpass.getuser(),
+            "unixname": getuser(),
         }
         assert serialized_ti == expected_json
 
@@ -133,7 +133,7 @@ class TestTaskInstanceSchema(unittest.TestCase):
             "state": "running",
             "task_id": "TEST_TASK_ID",
             "try_number": 0,
-            "unixname": getpass.getuser(),
+            "unixname": getuser(),
         }
         assert serialized_ti == expected_json
 
diff --git a/tests/jobs/test_base_job.py b/tests/jobs/test_base_job.py
index 3ffe4ad..093386b 100644
--- a/tests/jobs/test_base_job.py
+++ b/tests/jobs/test_base_job.py
@@ -121,7 +121,7 @@ class TestBaseJob:
     @conf_vars({('scheduler', 'max_tis_per_query'): '100'})
     @patch('airflow.jobs.base_job.ExecutorLoader.get_default_executor')
     @patch('airflow.jobs.base_job.get_hostname')
-    @patch('airflow.jobs.base_job.getpass.getuser')
+    @patch('airflow.jobs.base_job.getuser')
     def test_essential_attr(self, mock_getuser, mock_hostname, mock_default_executor):
         mock_sequential_executor = SequentialExecutor()
         mock_hostname.return_value = "test_hostname"
diff --git a/tests/providers/microsoft/winrm/hooks/test_winrm.py b/tests/providers/microsoft/winrm/hooks/test_winrm.py
index 042a2cd..aa598bf 100644
--- a/tests/providers/microsoft/winrm/hooks/test_winrm.py
+++ b/tests/providers/microsoft/winrm/hooks/test_winrm.py
@@ -102,7 +102,7 @@ class TestWinRMHook(unittest.TestCase):
             send_cbt=str(connection.extra_dejson['send_cbt']).lower() == 'true',
         )
 
-    @patch('airflow.providers.microsoft.winrm.hooks.winrm.getpass.getuser', return_value='user')
+    @patch('airflow.providers.microsoft.winrm.hooks.winrm.getuser', return_value='user')
     @patch('airflow.providers.microsoft.winrm.hooks.winrm.Protocol')
     def test_get_conn_no_username(self, mock_protocol, mock_getuser):
         winrm_hook = WinRMHook(remote_host='host', password='password')
diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py
index 6a3ab5d..35fb41b 100644
--- a/tests/task/task_runner/test_standard_task_runner.py
+++ b/tests/task/task_runner/test_standard_task_runner.py
@@ -15,7 +15,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import getpass
 import logging
 import os
 import time
@@ -30,6 +29,7 @@ from airflow.jobs.local_task_job import LocalTaskJob
 from airflow.models import TaskInstance as TI
 from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
 from airflow.utils import timezone
+from airflow.utils.platform import getuser
 from airflow.utils.state import State
 from tests.test_utils.db import clear_db_runs
 
@@ -105,7 +105,7 @@ class TestStandardTaskRunner:
     def test_start_and_terminate_run_as_user(self):
         local_task_job = mock.Mock()
         local_task_job.task_instance = mock.MagicMock()
-        local_task_job.task_instance.run_as_user = getpass.getuser()
+        local_task_job.task_instance.run_as_user = getuser()
         local_task_job.task_instance.command_as_list.return_value = [
             'airflow',
             'tasks',
@@ -142,7 +142,7 @@ class TestStandardTaskRunner:
         # Set up mock task
         local_task_job = mock.Mock()
         local_task_job.task_instance = mock.MagicMock()
-        local_task_job.task_instance.run_as_user = getpass.getuser()
+        local_task_job.task_instance.run_as_user = getuser()
         local_task_job.task_instance.command_as_list.return_value = [
             'airflow',
             'tasks',