You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/06/02 10:01:08 UTC
[airflow] branch master updated: Allow using Airflow with Flask CLI (#9030)
This is an automated email from the ASF dual-hosted git repository.
kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 87a4a0a Allow using Airflow with Flask CLI (#9030)
87a4a0a is described below
commit 87a4a0adff037fcfa299fdffd0877a8fdacdd428
Author: Kamil Breguła <mi...@users.noreply.github.com>
AuthorDate: Tue Jun 2 12:00:17 2020 +0200
Allow using Airflow with Flask CLI (#9030)
---
airflow/cli/commands/role_command.py | 6 +-
airflow/cli/commands/sync_perm_command.py | 4 +-
airflow/cli/commands/user_command.py | 14 ++---
airflow/cli/commands/webserver_command.py | 2 +-
airflow/www/app.py | 69 +++++++++++-----------
airflow/www/security.py | 7 +--
airflow/www/views.py | 21 ++++---
tests/cli/commands/test_celery_command.py | 5 +-
tests/cli/commands/test_role_command.py | 3 +-
tests/cli/commands/test_sync_perm_command.py | 19 +++---
tests/cli/commands/test_task_command.py | 5 +-
tests/cli/commands/test_user_command.py | 3 +-
tests/plugins/test_plugins_manager.py | 3 +-
.../www/api/experimental/test_dag_runs_endpoint.py | 2 +-
tests/www/api/experimental/test_endpoints.py | 3 +-
.../api/experimental/test_kerberos_endpoints.py | 2 +-
tests/www/test_app.py | 24 +++-----
tests/www/test_utils.py | 12 ++--
tests/www/test_views.py | 6 +-
19 files changed, 102 insertions(+), 108 deletions(-)
diff --git a/airflow/cli/commands/role_command.py b/airflow/cli/commands/role_command.py
index 802e060..b4e6f59 100644
--- a/airflow/cli/commands/role_command.py
+++ b/airflow/cli/commands/role_command.py
@@ -20,12 +20,12 @@
from tabulate import tabulate
from airflow.utils import cli as cli_utils
-from airflow.www.app import cached_appbuilder
+from airflow.www.app import cached_app
def roles_list(args):
"""Lists all existing roles"""
- appbuilder = cached_appbuilder()
+ appbuilder = cached_app().appbuilder # pylint: disable=no-member
roles = appbuilder.sm.get_all_roles()
print("Existing roles:\n")
role_names = sorted([[r.name] for r in roles])
@@ -38,6 +38,6 @@ def roles_list(args):
@cli_utils.action_logging
def roles_create(args):
"""Creates new empty role in DB"""
- appbuilder = cached_appbuilder()
+ appbuilder = cached_app().appbuilder # pylint: disable=no-member
for role_name in args.role:
appbuilder.sm.add_role(role_name)
diff --git a/airflow/cli/commands/sync_perm_command.py b/airflow/cli/commands/sync_perm_command.py
index 1d435b9..cf490b3 100644
--- a/airflow/cli/commands/sync_perm_command.py
+++ b/airflow/cli/commands/sync_perm_command.py
@@ -18,13 +18,13 @@
"""Sync permission command"""
from airflow.models import DagBag
from airflow.utils import cli as cli_utils
-from airflow.www.app import cached_appbuilder
+from airflow.www.app import cached_app
@cli_utils.action_logging
def sync_perm(args):
"""Updates permissions for existing roles and DAGs"""
- appbuilder = cached_appbuilder()
+ appbuilder = cached_app().appbuilder # pylint: disable=no-member
print('Updating permission, view-menu for all existing roles')
appbuilder.sm.sync_roles()
print('Updating permission on all DAG views')
diff --git a/airflow/cli/commands/user_command.py b/airflow/cli/commands/user_command.py
index a3db7bb..883f349 100644
--- a/airflow/cli/commands/user_command.py
+++ b/airflow/cli/commands/user_command.py
@@ -27,12 +27,12 @@ import sys
from tabulate import tabulate
from airflow.utils import cli as cli_utils
-from airflow.www.app import cached_appbuilder
+from airflow.www.app import cached_app
def users_list(args):
"""Lists users at the command line"""
- appbuilder = cached_appbuilder()
+ appbuilder = cached_app().appbuilder # pylint: disable=no-member
users = appbuilder.sm.get_all_users()
fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles']
users = [[user.__getattribute__(field) for field in fields] for user in users]
@@ -44,7 +44,7 @@ def users_list(args):
@cli_utils.action_logging
def users_create(args):
"""Creates new user in the DB"""
- appbuilder = cached_appbuilder()
+ appbuilder = cached_app().appbuilder # pylint: disable=no-member
role = appbuilder.sm.find_role(args.role)
if not role:
valid_roles = appbuilder.sm.get_all_roles()
@@ -74,7 +74,7 @@ def users_create(args):
@cli_utils.action_logging
def users_delete(args):
"""Deletes user from DB"""
- appbuilder = cached_appbuilder()
+ appbuilder = cached_app().appbuilder # pylint: disable=no-member
try:
user = next(u for u in appbuilder.sm.get_all_users()
@@ -98,7 +98,7 @@ def users_manage_role(args, remove=False):
raise SystemExit('Conflicting args: must supply either --username'
' or --email, but not both')
- appbuilder = cached_appbuilder()
+ appbuilder = cached_app().appbuilder # pylint: disable=no-member
user = (appbuilder.sm.find_user(username=args.username) or
appbuilder.sm.find_user(email=args.email))
if not user:
@@ -136,7 +136,7 @@ def users_manage_role(args, remove=False):
def users_export(args):
"""Exports all users to the json file"""
- appbuilder = cached_appbuilder()
+ appbuilder = cached_app().appbuilder # pylint: disable=no-member
users = appbuilder.sm.get_all_users()
fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles']
@@ -184,7 +184,7 @@ def users_import(args):
def _import_users(users_list): # pylint: disable=redefined-outer-name
- appbuilder = cached_appbuilder()
+ appbuilder = cached_app().appbuilder # pylint: disable=no-member
users_created = []
users_updated = []
diff --git a/airflow/cli/commands/webserver_command.py b/airflow/cli/commands/webserver_command.py
index 8724cea..cbbdce6 100644
--- a/airflow/cli/commands/webserver_command.py
+++ b/airflow/cli/commands/webserver_command.py
@@ -194,7 +194,7 @@ def webserver(args):
print(
"Starting the web server on port {0} and host {1}.".format(
args.port, args.hostname))
- app, _ = create_app(testing=conf.getboolean('core', 'unit_test_mode'))
+ app = create_app(testing=conf.getboolean('core', 'unit_test_mode'))
app.run(debug=True, use_reloader=not app.config['TESTING'],
port=args.port, host=args.hostname,
ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None)
diff --git a/airflow/www/app.py b/airflow/www/app.py
index 5e9fbdd..eb5589f 100644
--- a/airflow/www/app.py
+++ b/airflow/www/app.py
@@ -20,7 +20,7 @@ import datetime
import logging
import socket
from datetime import timedelta
-from typing import Any, Optional
+from typing import Optional
from urllib.parse import urlparse
import flask
@@ -39,15 +39,18 @@ from airflow.logging_config import configure_logging
from airflow.utils.json import AirflowJsonEncoder
from airflow.www.static_config import configure_manifest_files
-app = None # type: Any
-appbuilder = None # type: Optional[AppBuilder]
+app: Optional[Flask] = None
csrf = CSRFProtect()
log = logging.getLogger(__name__)
+def root_app(env, resp):
+ resp(b'404 Not Found', [('Content-Type', 'text/plain')])
+ return [b'Apache Airflow is not at this location']
+
+
def create_app(config=None, testing=False, app_name="Airflow"):
- global app, appbuilder
app = Flask(__name__)
app.secret_key = conf.get('webserver', 'SECRET_KEY')
@@ -70,6 +73,31 @@ def create_app(config=None, testing=False, app_name="Airflow"):
app.json_encoder = AirflowJsonEncoder
csrf.init_app(app)
+
+ def apply_middlewares(flask_app: Flask):
+ # Apply DispatcherMiddleware
+ base_url = urlparse(conf.get('webserver', 'base_url'))[2]
+ if not base_url or base_url == '/':
+ base_url = ""
+ if base_url:
+ flask_app.wsgi_app = DispatcherMiddleware( # type: ignore
+ root_app,
+ mounts={base_url: flask_app.wsgi_app}
+ )
+
+ # Apply ProxyFix middleware
+ if conf.getboolean('webserver', 'ENABLE_PROXY_FIX'):
+ flask_app.wsgi_app = ProxyFix( # type: ignore
+ flask_app.wsgi_app,
+ x_for=conf.getint("webserver", "PROXY_FIX_X_FOR", fallback=1),
+ x_proto=conf.getint("webserver", "PROXY_FIX_X_PROTO", fallback=1),
+ x_host=conf.getint("webserver", "PROXY_FIX_X_HOST", fallback=1),
+ x_port=conf.getint("webserver", "PROXY_FIX_X_PORT", fallback=1),
+ x_prefix=conf.getint("webserver", "PROXY_FIX_X_PREFIX", fallback=1)
+ )
+
+ apply_middlewares(app)
+
db = SQLA()
db.session = settings.Session
db.init_app(app)
@@ -286,36 +314,11 @@ def create_app(config=None, testing=False, app_name="Airflow"):
def make_session_permanent():
flask_session.permanent = True
- return app, appbuilder
-
-
-def root_app(env, resp):
- resp(b'404 Not Found', [('Content-Type', 'text/plain')])
- return [b'Apache Airflow is not at this location']
+ return app
def cached_app(config=None, testing=False):
- global app, appbuilder
- if not app or not appbuilder:
- base_url = urlparse(conf.get('webserver', 'base_url'))[2]
- if not base_url or base_url == '/':
- base_url = ""
-
- app, _ = create_app(config=config, testing=testing)
- app = DispatcherMiddleware(root_app, {base_url: app})
- if conf.getboolean('webserver', 'ENABLE_PROXY_FIX'):
- app = ProxyFix(
- app,
- x_for=conf.getint("webserver", "PROXY_FIX_X_FOR", fallback=1),
- x_proto=conf.getint("webserver", "PROXY_FIX_X_PROTO", fallback=1),
- x_host=conf.getint("webserver", "PROXY_FIX_X_HOST", fallback=1),
- x_port=conf.getint("webserver", "PROXY_FIX_X_PORT", fallback=1),
- x_prefix=conf.getint("webserver", "PROXY_FIX_X_PREFIX", fallback=1)
- )
+ global app
+ if not app:
+ app = create_app(config=config, testing=testing)
return app
-
-
-def cached_appbuilder(config=None, testing=False):
- global appbuilder
- cached_app(config=config, testing=testing)
- return appbuilder
diff --git a/airflow/www/security.py b/airflow/www/security.py
index aa4d532..05e15b8 100644
--- a/airflow/www/security.py
+++ b/airflow/www/security.py
@@ -17,7 +17,7 @@
# under the License.
#
-from flask import g
+from flask import current_app, g
from flask_appbuilder.security.sqla import models as sqla_models
from flask_appbuilder.security.sqla.manager import SecurityManager
from sqlalchemy import and_, or_
@@ -26,7 +26,6 @@ from airflow import models
from airflow.exceptions import AirflowException
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
-from airflow.www.app import appbuilder
from airflow.www.utils import CustomSQLAInterface
EXISTING_ROLES = {
@@ -250,8 +249,8 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin):
if user is None:
user = g.user
if user.is_anonymous:
- public_role = appbuilder.config.get('AUTH_ROLE_PUBLIC')
- return [appbuilder.security_manager.find_role(public_role)] \
+ public_role = current_app.appbuilder.config.get('AUTH_ROLE_PUBLIC')
+ return [current_app.appbuilder.security_manager.find_role(public_role)] \
if public_role else []
return user.roles
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 5d427bf..76ed135 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -35,7 +35,7 @@ import lazy_object_proxy
import markdown
import sqlalchemy as sqla
from flask import (
- Markup, Response, escape, flash, jsonify, make_response, redirect, render_template, request,
+ Markup, Response, current_app, escape, flash, jsonify, make_response, redirect, render_template, request,
session as flask_session, url_for,
)
from flask_appbuilder import BaseView, ModelView, expose, has_access, permission_name
@@ -72,7 +72,6 @@ from airflow.utils.helpers import alchemy_to_dict, render_log_filename
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State
from airflow.www import utils as wwwutils
-from airflow.www.app import appbuilder
from airflow.www.decorators import action_logging, gzipped, has_dag_access
from airflow.www.forms import (
ConnectionForm, DagRunForm, DateTimeForm, DateTimeWithNumRunsForm, DateTimeWithNumRunsWithDagRunsForm,
@@ -270,7 +269,7 @@ class Airflow(AirflowBaseView):
end = start + dags_per_page
# Get all the dag id the user could access
- filter_dag_ids = appbuilder.sm.get_accessible_dag_ids()
+ filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
with create_session() as session:
# read orm_dags from the db
@@ -368,7 +367,7 @@ class Airflow(AirflowBaseView):
def dag_stats(self, session=None):
dr = models.DagRun
- allowed_dag_ids = appbuilder.sm.get_accessible_dag_ids()
+ allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
@@ -416,7 +415,7 @@ class Airflow(AirflowBaseView):
DagRun = models.DagRun
Dag = models.DagModel
- allowed_dag_ids = set(appbuilder.sm.get_accessible_dag_ids())
+ allowed_dag_ids = set(current_app.appbuilder.sm.get_accessible_dag_ids())
if not allowed_dag_ids:
return wwwutils.json_response({})
@@ -512,7 +511,7 @@ class Airflow(AirflowBaseView):
def last_dagruns(self, session=None):
DagRun = models.DagRun
- allowed_dag_ids = appbuilder.sm.get_accessible_dag_ids()
+ allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
@@ -1167,7 +1166,7 @@ class Airflow(AirflowBaseView):
@has_access
@provide_session
def blocked(self, session=None):
- allowed_dag_ids = appbuilder.sm.get_accessible_dag_ids()
+ allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
@@ -1912,7 +1911,7 @@ class Airflow(AirflowBaseView):
dag = dagbag.get_dag(dag_id)
# sync dag permission
- appbuilder.sm.sync_perm_for_dag(dag_id, dag.access_control)
+ current_app.appbuilder.sm.sync_perm_for_dag(dag_id, dag.access_control)
flash("DAG [{}] is now fresh as a daisy".format(dag_id))
return redirect(request.referrer)
@@ -2163,9 +2162,9 @@ class ConfigurationView(AirflowBaseView):
class DagFilter(BaseFilter):
def apply(self, query, func): # noqa
- if appbuilder.sm.has_all_dags_access():
+ if current_app.appbuilder.sm.has_all_dags_access():
return query
- filter_dag_ids = appbuilder.sm.get_accessible_dag_ids()
+ filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
return query.filter(self.model.dag_id.in_(filter_dag_ids))
@@ -2800,7 +2799,7 @@ class DagModelView(AirflowModelView):
dag_ids_query = dag_ids_query.filter(DagModel.is_paused)
owners_query = owners_query.filter(DagModel.is_paused)
- filter_dag_ids = appbuilder.sm.get_accessible_dag_ids()
+ filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
if 'all_dags' not in filter_dag_ids:
dag_ids_query = dag_ids_query.filter(DagModel.dag_id.in_(filter_dag_ids))
owners_query = owners_query.filter(DagModel.dag_id.in_(filter_dag_ids))
diff --git a/tests/cli/commands/test_celery_command.py b/tests/cli/commands/test_celery_command.py
index df7d133..23c2c10 100644
--- a/tests/cli/commands/test_celery_command.py
+++ b/tests/cli/commands/test_celery_command.py
@@ -29,9 +29,6 @@ from airflow.cli.commands import celery_command
from airflow.configuration import conf
from tests.test_utils.config import conf_vars
-mock.patch('airflow.utils.cli.action_logging', lambda x: x).start()
-mock_args = Namespace(queues=1, concurrency=1)
-
class TestWorkerPrecheck(unittest.TestCase):
@mock.patch('airflow.settings.validate_session')
@@ -42,7 +39,7 @@ class TestWorkerPrecheck(unittest.TestCase):
"""
mock_validate_session.return_value = False
with self.assertRaises(SystemExit) as cm:
- celery_command.worker(mock_args)
+ celery_command.worker(Namespace(queues=1, concurrency=1))
self.assertEqual(cm.exception.code, 1)
@conf_vars({('core', 'worker_precheck'): 'False'})
diff --git a/tests/cli/commands/test_role_command.py b/tests/cli/commands/test_role_command.py
index 38f1a22..20af879 100644
--- a/tests/cli/commands/test_role_command.py
+++ b/tests/cli/commands/test_role_command.py
@@ -36,7 +36,8 @@ class TestCliRoles(unittest.TestCase):
def setUp(self):
from airflow.www import app as application
- self.app, self.appbuilder = application.create_app(testing=True)
+ self.app = application.create_app(testing=True)
+ self.appbuilder = self.app.appbuilder # pylint: disable=no-member
self.clear_roles_and_roles()
def tearDown(self):
diff --git a/tests/cli/commands/test_sync_perm_command.py b/tests/cli/commands/test_sync_perm_command.py
index ee4c477..88753c7 100644
--- a/tests/cli/commands/test_sync_perm_command.py
+++ b/tests/cli/commands/test_sync_perm_command.py
@@ -31,12 +31,9 @@ class TestCliSyncPerm(unittest.TestCase):
cls.dagbag = DagBag(include_examples=True)
cls.parser = cli_parser.get_parser()
- def setUp(self):
- from airflow.www import app as application
- self.app, self.appbuilder = application.create_app(testing=True)
-
+ @mock.patch("airflow.cli.commands.sync_perm_command.cached_app")
@mock.patch("airflow.cli.commands.sync_perm_command.DagBag")
- def test_cli_sync_perm(self, dagbag_mock):
+ def test_cli_sync_perm(self, dagbag_mock, mock_cached_app):
self.expect_dagbag_contains([
DAG('has_access_control',
access_control={
@@ -44,22 +41,22 @@ class TestCliSyncPerm(unittest.TestCase):
}),
DAG('no_access_control')
], dagbag_mock)
- self.appbuilder.sm = mock.Mock()
+ appbuilder = mock_cached_app.return_value.appbuilder
+ appbuilder.sm = mock.Mock()
args = self.parser.parse_args([
'sync_perm'
])
sync_perm_command.sync_perm(args)
- assert self.appbuilder.sm.sync_roles.call_count == 1
+ assert appbuilder.sm.sync_roles.call_count == 1
- self.assertEqual(2,
- len(self.appbuilder.sm.sync_perm_for_dag.mock_calls))
- self.appbuilder.sm.sync_perm_for_dag.assert_any_call(
+ self.assertEqual(2, len(appbuilder.sm.sync_perm_for_dag.mock_calls))
+ appbuilder.sm.sync_perm_for_dag.assert_any_call(
'has_access_control',
{'Public': {'can_dag_read'}}
)
- self.appbuilder.sm.sync_perm_for_dag.assert_any_call(
+ appbuilder.sm.sync_perm_for_dag.assert_any_call(
'no_access_control',
None,
)
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 93a9946..7afabef 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -170,9 +170,10 @@ class TestCliTasks(unittest.TestCase):
def test_task_states_for_dag_run(self):
dag2 = DagBag().dags['example_python_operator']
-
task2 = dag2.get_task(task_id='print_the_context')
defaut_date2 = timezone.make_aware(datetime(2016, 1, 9))
+ dag2.clear()
+
ti2 = TaskInstance(task2, defaut_date2)
ti2.set_state(State.SUCCESS)
@@ -201,7 +202,7 @@ class TestCliTasks(unittest.TestCase):
tablefmt="plain")
# Check that prints, and log messages, are shown
- self.assertEqual(expected.replace("\n", ""), actual_out.replace("\n", ""))
+ self.assertIn(expected.replace("\n", ""), actual_out.replace("\n", ""))
def test_subdag_clear(self):
args = self.parser.parse_args([
diff --git a/tests/cli/commands/test_user_command.py b/tests/cli/commands/test_user_command.py
index 9dfbb3d..ce7573e 100644
--- a/tests/cli/commands/test_user_command.py
+++ b/tests/cli/commands/test_user_command.py
@@ -47,7 +47,8 @@ class TestCliUsers(unittest.TestCase):
def setUp(self):
from airflow.www import app as application
- self.app, self.appbuilder = application.create_app(testing=True)
+ self.app = application.create_app(testing=True)
+ self.appbuilder = self.app.appbuilder # pylint: disable=no-member
self.clear_roles_and_roles()
def tearDown(self):
diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py
index 78da6ba..460e0da 100644
--- a/tests/plugins/test_plugins_manager.py
+++ b/tests/plugins/test_plugins_manager.py
@@ -26,7 +26,8 @@ from airflow.www import app as application
class TestPluginsRBAC(unittest.TestCase):
def setUp(self):
- self.app, self.appbuilder = application.create_app(testing=True)
+ self.app = application.create_app(testing=True)
+ self.appbuilder = self.app.appbuilder # pylint: disable=no-member
def test_flaskappbuilder_views(self):
from tests.plugins.test_plugin import v_appbuilder_package
diff --git a/tests/www/api/experimental/test_dag_runs_endpoint.py b/tests/www/api/experimental/test_dag_runs_endpoint.py
index eef00fe..9a4b257 100644
--- a/tests/www/api/experimental/test_dag_runs_endpoint.py
+++ b/tests/www/api/experimental/test_dag_runs_endpoint.py
@@ -49,7 +49,7 @@ class TestDagRunsEndpoint(unittest.TestCase):
def setUp(self):
super().setUp()
- app, _ = application.create_app(testing=True)
+ app = application.create_app(testing=True)
self.app = app.test_client()
def tearDown(self):
diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py
index 4d85b5c..bef620d 100644
--- a/tests/www/api/experimental/test_endpoints.py
+++ b/tests/www/api/experimental/test_endpoints.py
@@ -42,7 +42,8 @@ ROOT_FOLDER = os.path.realpath(
class TestBase(unittest.TestCase):
def setUp(self):
- self.app, self.appbuilder = application.create_app(testing=True)
+ self.app = application.create_app(testing=True)
+ self.appbuilder = self.app.appbuilder # pylint: disable=no-member
self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
self.app.config['SECRET_KEY'] = 'secret_key'
self.app.config['CSRF_ENABLED'] = False
diff --git a/tests/www/api/experimental/test_kerberos_endpoints.py b/tests/www/api/experimental/test_kerberos_endpoints.py
index 0f8e735..43c72aa 100644
--- a/tests/www/api/experimental/test_kerberos_endpoints.py
+++ b/tests/www/api/experimental/test_kerberos_endpoints.py
@@ -39,7 +39,7 @@ class TestApiKerberos(unittest.TestCase):
("kerberos", "keytab"): KRB5_KTNAME,
})
def setUp(self):
- self.app, _ = application.create_app(testing=True)
+ self.app = application.create_app(testing=True)
def test_trigger_dag(self):
with self.app.test_client() as client:
diff --git a/tests/www/test_app.py b/tests/www/test_app.py
index 5bad03f..a16562b 100644
--- a/tests/www/test_app.py
+++ b/tests/www/test_app.py
@@ -38,11 +38,9 @@ class TestApp(unittest.TestCase):
('webserver', 'proxy_fix_x_prefix'): '1'
})
@mock.patch("airflow.www.app.app", None)
- @mock.patch("airflow.www.app.appbuilder", None)
def test_should_respect_proxy_fix(self):
app = application.cached_app(testing=True)
- flask_app = next(iter(app.app.mounts.values()))
- flask_app.url_map.add(Rule("/debug", endpoint="debug"))
+ app.url_map.add(Rule("/debug", endpoint="debug"))
def debug_view():
from flask import request
@@ -55,7 +53,7 @@ class TestApp(unittest.TestCase):
return Response("success")
- flask_app.view_functions['debug'] = debug_view
+ app.view_functions['debug'] = debug_view
new_environ = {
"PATH_INFO": "/debug",
@@ -78,11 +76,9 @@ class TestApp(unittest.TestCase):
('webserver', 'base_url'): 'http://localhost:8080/internal-client',
})
@mock.patch("airflow.www.app.app", None)
- @mock.patch("airflow.www.app.appbuilder", None)
def test_should_respect_base_url_ignore_proxy_headers(self):
app = application.cached_app(testing=True)
- flask_app = next(iter(app.mounts.values()))
- flask_app.url_map.add(Rule("/debug", endpoint="debug"))
+ app.url_map.add(Rule("/debug", endpoint="debug"))
def debug_view():
from flask import request
@@ -95,7 +91,7 @@ class TestApp(unittest.TestCase):
return Response("success")
- flask_app.view_functions['debug'] = debug_view
+ app.view_functions['debug'] = debug_view
new_environ = {
"PATH_INFO": "/internal-client/debug",
@@ -124,11 +120,9 @@ class TestApp(unittest.TestCase):
('webserver', 'proxy_fix_x_prefix'): '1'
})
@mock.patch("airflow.www.app.app", None)
- @mock.patch("airflow.www.app.appbuilder", None)
def test_should_respect_base_url_when_proxy_fix_and_base_url_is_set_up_but_headers_missing(self):
app = application.cached_app(testing=True)
- flask_app = next(iter(app.app.mounts.values()))
- flask_app.url_map.add(Rule("/debug", endpoint="debug"))
+ app.url_map.add(Rule("/debug", endpoint="debug"))
def debug_view():
from flask import request
@@ -140,7 +134,7 @@ class TestApp(unittest.TestCase):
return Response("success")
- flask_app.view_functions['debug'] = debug_view
+ app.view_functions['debug'] = debug_view
new_environ = {
"PATH_INFO": "/internal-client/debug",
@@ -164,11 +158,9 @@ class TestApp(unittest.TestCase):
('webserver', 'proxy_fix_x_prefix'): '1'
})
@mock.patch("airflow.www.app.app", None)
- @mock.patch("airflow.www.app.appbuilder", None)
def test_should_respect_base_url_and_proxy_when_proxy_fix_and_base_url_is_set_up(self):
app = application.cached_app(testing=True)
- flask_app = next(iter(app.app.mounts.values()))
- flask_app.url_map.add(Rule("/debug", endpoint="debug"))
+ app.url_map.add(Rule("/debug", endpoint="debug"))
def debug_view():
from flask import request
@@ -181,7 +173,7 @@ class TestApp(unittest.TestCase):
return Response("success")
- flask_app.view_functions['debug'] = debug_view
+ app.view_functions['debug'] = debug_view
new_environ = {
"PATH_INFO": "/internal-client/debug",
diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py
index f6fb89e..b775d78 100644
--- a/tests/www/test_utils.py
+++ b/tests/www/test_utils.py
@@ -136,8 +136,8 @@ class TestUtils(unittest.TestCase):
def test_task_instance_link(self):
- from airflow.www.app import cached_appbuilder
- with cached_appbuilder(testing=True).app.test_request_context():
+ from airflow.www.app import cached_app
+ with cached_app(testing=True).test_request_context():
html = str(utils.task_instance_link({
'dag_id': '<a&1>',
'task_id': '<b2>',
@@ -150,8 +150,8 @@ class TestUtils(unittest.TestCase):
self.assertNotIn('<b2>', html)
def test_dag_link(self):
- from airflow.www.app import cached_appbuilder
- with cached_appbuilder(testing=True).app.test_request_context():
+ from airflow.www.app import cached_app
+ with cached_app(testing=True).test_request_context():
html = str(utils.dag_link({
'dag_id': '<a&1>',
'execution_date': datetime.now()
@@ -161,8 +161,8 @@ class TestUtils(unittest.TestCase):
self.assertNotIn('<a&1>', html)
def test_dag_run_link(self):
- from airflow.www.app import cached_appbuilder
- with cached_appbuilder(testing=True).app.test_request_context():
+ from airflow.www.app import cached_app
+ with cached_app(testing=True).test_request_context():
html = str(utils.dag_run_link({
'dag_id': '<a&1>',
'run_id': '<b2>',
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index 4ebd222..603c198 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -116,7 +116,8 @@ class TestBase(unittest.TestCase):
def setUpClass(cls):
settings.configure_orm()
cls.session = settings.Session
- cls.app, cls.appbuilder = application.create_app(testing=True)
+ cls.app = application.create_app(testing=True)
+ cls.appbuilder = cls.app.appbuilder # pylint: disable=no-member
cls.app.config['WTF_CSRF_ENABLED'] = False
cls.app.jinja_env.undefined = jinja2.StrictUndefined
@@ -1043,7 +1044,8 @@ class TestLogView(TestBase):
sys.path.append(self.settings_folder)
with conf_vars({('logging', 'logging_config_class'): 'airflow_local_settings.LOGGING_CONFIG'}):
- self.app, self.appbuilder = application.create_app(testing=True)
+ self.app = application.create_app(testing=True)
+ self.appbuilder = self.app.appbuilder # pylint: disable=no-member
self.app.config['WTF_CSRF_ENABLED'] = False
self.client = self.app.test_client()
settings.configure_orm()