You are viewing a plain text version of this content. The canonical link for it was a hyperlink on the original archive page and is not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/03/06 18:57:58 UTC

[airflow] branch main updated: Fixing bug when roles list is empty (#18590)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9a9d54b  Fixing bug when roles list is empty (#18590)
9a9d54b is described below

commit 9a9d54ba98f1578ff0164125d6a8f916124d3e21
Author: Khalid Mammadov <xm...@hotmail.com>
AuthorDate: Sun Mar 6 18:57:14 2022 +0000

    Fixing bug when roles list is empty (#18590)
---
 airflow/cli/commands/user_command.py    | 33 +++++++++++++---
 tests/cli/commands/test_user_command.py | 69 +++++++++++++++++++++++++++------
 tests/test_utils/api_connexion_utils.py |  5 +++
 3 files changed, 90 insertions(+), 17 deletions(-)

diff --git a/airflow/cli/commands/user_command.py b/airflow/cli/commands/user_command.py
index 5cbfe2b..85c11d6 100644
--- a/airflow/cli/commands/user_command.py
+++ b/airflow/cli/commands/user_command.py
@@ -22,6 +22,10 @@ import os
 import random
 import re
 import string
+from typing import Any, Dict, List
+
+from marshmallow import Schema, fields, validate
+from marshmallow.exceptions import ValidationError
 
 from airflow.cli.simple_table import AirflowConsole
 from airflow.utils import cli as cli_utils
@@ -29,6 +33,17 @@ from airflow.utils.cli import suppress_logs_and_warning
 from airflow.www.app import cached_app
 
 
+class UserSchema(Schema):
+    """user collection item schema"""
+
+    id = fields.Int()
+    firstname = fields.Str(required=True)
+    lastname = fields.Str(required=True)
+    username = fields.Str(required=True)
+    email = fields.Email(required=True)
+    roles = fields.List(fields.Str, required=True, validate=validate.Length(min=1))
+
+
 @suppress_logs_and_warning
 def users_list(args):
     """Lists users at the command line"""
@@ -174,12 +189,23 @@ def users_import(args):
         print("Updated the following users:\n\t{}".format("\n\t".join(users_updated)))
 
 
-def _import_users(users_list):
+def _import_users(users_list: List[Dict[str, Any]]):
     appbuilder = cached_app().appbuilder
     users_created = []
     users_updated = []
 
+    try:
+        UserSchema(many=True).load(users_list)
+    except ValidationError as e:
+        msg = []
+        for row_num, failure in e.messages.items():
+            msg.append(f'[Item {row_num}]')
+            for key, value in failure.items():
+                msg.append(f'\t{key}: {value}')
+        raise SystemExit("Error: Input file didn't pass validation. See below:\n{}".format('\n'.join(msg)))
+
     for user in users_list:
+
         roles = []
         for rolename in user['roles']:
             role = appbuilder.sm.find_role(rolename)
@@ -189,11 +215,6 @@ def _import_users(users_list):
 
             roles.append(role)
 
-        required_fields = ['username', 'firstname', 'lastname', 'email', 'roles']
-        for field in required_fields:
-            if not user.get(field):
-                raise SystemExit(f"Error: '{field}' is a required field, but was not specified")
-
         existing_user = appbuilder.sm.find_user(email=user['email'])
         if existing_user:
             print(f"Found existing user with email '{user['email']}'")
diff --git a/tests/cli/commands/test_user_command.py b/tests/cli/commands/test_user_command.py
index a03cf97..234c0c8 100644
--- a/tests/cli/commands/test_user_command.py
+++ b/tests/cli/commands/test_user_command.py
@@ -18,15 +18,18 @@
 import io
 import json
 import os
+import re
 import tempfile
 from contextlib import redirect_stdout
 
 import pytest
 
 from airflow.cli.commands import user_command
+from tests.test_utils.api_connexion_utils import delete_users
 
 TEST_USER1_EMAIL = 'test-user1@example.com'
 TEST_USER2_EMAIL = 'test-user2@example.com'
+TEST_USER3_EMAIL = 'test-user3@example.com'
 
 
 def _does_user_belong_to_role(appbuilder, email, rolename):
@@ -45,18 +48,9 @@ class TestCliUsers:
         self.dagbag = dagbag
         self.parser = parser
         self.appbuilder = self.app.appbuilder
-        self.clear_roles_and_roles()
+        delete_users(app)
         yield
-        self.clear_roles_and_roles()
-
-    def clear_roles_and_roles(self):
-        for email in [TEST_USER1_EMAIL, TEST_USER2_EMAIL]:
-            test_user = self.appbuilder.sm.find_user(email=email)
-            if test_user:
-                self.appbuilder.sm.del_register_user(test_user)
-        for role_name in ['FakeTeamA', 'FakeTeamB']:
-            if self.appbuilder.sm.find_role(role_name):
-                self.appbuilder.sm.delete_role(role_name)
+        delete_users(app)
 
     def test_cli_create_user_random_password(self):
         args = self.parser.parse_args(
@@ -411,3 +405,56 @@ class TestCliUsers:
                 user_command.add_role(args)
             else:
                 user_command.remove_role(args)
+
+    @pytest.mark.parametrize(
+        "user, message",
+        [
+            [
+                {
+                    "username": "imported_user1",
+                    "lastname": "doe1",
+                    "firstname": "john",
+                    "email": TEST_USER1_EMAIL,
+                    "roles": "This is not a list",
+                },
+                "Error: Input file didn't pass validation. See below:\n"
+                "[Item 0]\n"
+                "\troles: ['Not a valid list.']",
+            ],
+            [
+                {
+                    "username": "imported_user2",
+                    "lastname": "doe2",
+                    "firstname": "jon",
+                    "email": TEST_USER2_EMAIL,
+                    "roles": [],
+                },
+                "Error: Input file didn't pass validation. See below:\n"
+                "[Item 0]\n"
+                "\troles: ['Shorter than minimum length 1.']",
+            ],
+            [
+                {
+                    "username1": "imported_user3",
+                    "lastname": "doe3",
+                    "firstname": "jon",
+                    "email": TEST_USER3_EMAIL,
+                    "roles": ["Test"],
+                },
+                "Error: Input file didn't pass validation. See below:\n"
+                "[Item 0]\n"
+                "\tusername: ['Missing data for required field.']\n"
+                "\tusername1: ['Unknown field.']",
+            ],
+            [
+                "Wrong input",
+                "Error: Input file didn't pass validation. See below:\n"
+                "[Item 0]\n"
+                "\t_schema: ['Invalid input type.']",
+            ],
+        ],
+        ids=["Incorrect roles", "Empty roles", "Required field is missing", "Wrong input"],
+    )
+    def test_cli_import_users_exceptions(self, user, message):
+        with pytest.raises(SystemExit, match=re.escape(message)):
+            self._import_users_from_file([user])
diff --git a/tests/test_utils/api_connexion_utils.py b/tests/test_utils/api_connexion_utils.py
index e8063f2..f4553c6 100644
--- a/tests/test_utils/api_connexion_utils.py
+++ b/tests/test_utils/api_connexion_utils.py
@@ -110,6 +110,11 @@ def delete_user(app, username):
             break
 
 
+def delete_users(app):
+    for user in app.appbuilder.sm.get_all_users():
+        delete_user(app, user.username)
+
+
 def assert_401(response):
     assert response.status_code == 401, f"Current code: {response.status_code}"
     assert response.json == {