You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/11/19 19:46:13 UTC

[airflow] 04/06: Pass SQLAlchemy engine options to FAB based UI (#11395)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit d4160a5f8ac7b8a53f52dd5a08ce532d19483192
Author: Michał Misiewicz <mi...@gmail.com>
AuthorDate: Fri Oct 16 19:55:41 2020 +0200

    Pass SQLAlchemy engine options to FAB based UI (#11395)
    
    Co-authored-by: Tomek Urbaszek <tu...@gmail.com>
    (cherry picked from commit 91484b938f0b6f943404f1aeb3e63b61b808cfe9)
---
 airflow/settings.py                  | 57 +++++++++++++++++++-----------------
 airflow/www/app.py                   |  3 ++
 airflow/www_rbac/app.py              |  4 +++
 tests/core/test_sqlalchemy_config.py |  3 +-
 tests/www/test_app.py                |  1 -
 tests/www_rbac/test_app.py           | 26 +++++++++++++++-
 6 files changed, 64 insertions(+), 30 deletions(-)

diff --git a/airflow/settings.py b/airflow/settings.py
index e39c960..0f35dc5 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -36,7 +36,6 @@ from sqlalchemy.pool import NullPool
 
 from airflow.configuration import conf, AIRFLOW_HOME, WEBSERVER_CONFIG  # NOQA F401
 from airflow.logging_config import configure_logging
-from airflow.utils.module_loading import import_string
 from airflow.utils.sqlalchemy import setup_event_handlers
 
 log = logging.getLogger(__name__)
@@ -233,12 +232,38 @@ def configure_orm(disable_connection_pool=False):
     log.debug("Setting up DB connection pool (PID %s)" % os.getpid())
     global engine
     global Session
-    engine_args = {}
+    engine_args = prepare_engine_args(disable_connection_pool)
+
+    # Allow the user to specify an encoding for their DB otherwise default
+    # to utf-8 so jobs & users with non-latin1 characters can still use us.
+    engine_args['encoding'] = conf.get('core', 'SQL_ENGINE_ENCODING', fallback='utf-8')
+
+    # For Python2 we get back a newstr and need a str
+    engine_args['encoding'] = engine_args['encoding'].__str__()
+
+    if conf.has_option('core', 'sql_alchemy_connect_args'):
+        connect_args = conf.getimport('core', 'sql_alchemy_connect_args')
+    else:
+        connect_args = {}
+
+    engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args)
+    setup_event_handlers(engine)
+
+    Session = scoped_session(sessionmaker(
+        autocommit=False,
+        autoflush=False,
+        bind=engine,
+        expire_on_commit=False,
+    ))
 
+
+def prepare_engine_args(disable_connection_pool=False):
+    """Prepare SQLAlchemy engine args"""
+    engine_args = {}
     pool_connections = conf.getboolean('core', 'SQL_ALCHEMY_POOL_ENABLED')
     if disable_connection_pool or not pool_connections:
         engine_args['poolclass'] = NullPool
-        log.debug("settings.configure_orm(): Using NullPool")
+        log.debug("settings.prepare_engine_args(): Using NullPool")
     elif 'sqlite' not in SQL_ALCHEMY_CONN:
         # Pool size engine args not supported by sqlite.
         # If no config value is defined for the pool size, select a reasonable value.
@@ -270,35 +295,13 @@ def configure_orm(disable_connection_pool=False):
         # https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic
         pool_pre_ping = conf.getboolean('core', 'SQL_ALCHEMY_POOL_PRE_PING', fallback=True)
 
-        log.debug("settings.configure_orm(): Using pool settings. pool_size=%d, max_overflow=%d, "
+        log.debug("settings.prepare_engine_args(): Using pool settings. pool_size=%d, max_overflow=%d, "
                   "pool_recycle=%d, pid=%d", pool_size, max_overflow, pool_recycle, os.getpid())
         engine_args['pool_size'] = pool_size
         engine_args['pool_recycle'] = pool_recycle
         engine_args['pool_pre_ping'] = pool_pre_ping
         engine_args['max_overflow'] = max_overflow
-
-    # Allow the user to specify an encoding for their DB otherwise default
-    # to utf-8 so jobs & users with non-latin1 characters can still use
-    # us.
-    engine_args['encoding'] = conf.get('core', 'SQL_ENGINE_ENCODING', fallback='utf-8')
-    # For Python2 we get back a newstr and need a str
-    engine_args['encoding'] = engine_args['encoding'].__str__()
-
-    if conf.has_option('core', 'sql_alchemy_connect_args'):
-        connect_args = import_string(
-            conf.get('core', 'sql_alchemy_connect_args')
-        )
-    else:
-        connect_args = {}
-
-    engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args)
-    setup_event_handlers(engine)
-
-    Session = scoped_session(
-        sessionmaker(autocommit=False,
-                     autoflush=False,
-                     bind=engine,
-                     expire_on_commit=False))
+    return engine_args
 
 
 def dispose_orm():
diff --git a/airflow/www/app.py b/airflow/www/app.py
index b101f45..7d0dae7 100644
--- a/airflow/www/app.py
+++ b/airflow/www/app.py
@@ -17,9 +17,12 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+import datetime
 import logging
 from typing import Any
 
+import flask
+import flask_login
 import six
 from flask import Flask
 from flask_admin import Admin, base
diff --git a/airflow/www_rbac/app.py b/airflow/www_rbac/app.py
index 46ad120..29a364b 100644
--- a/airflow/www_rbac/app.py
+++ b/airflow/www_rbac/app.py
@@ -46,6 +46,7 @@ csrf = CSRFProtect()
 
 log = logging.getLogger(__name__)
 
+
 def create_app(config=None, session=None, testing=False, app_name="Airflow"):
     global app, appbuilder
     app = Flask(__name__)
@@ -76,6 +77,9 @@ def create_app(config=None, session=None, testing=False, app_name="Airflow"):
     if config:
         app.config.from_mapping(config)
 
+    if 'SQLALCHEMY_ENGINE_OPTIONS' not in app.config:
+        app.config['SQLALCHEMY_ENGINE_OPTIONS'] = settings.prepare_engine_args()
+
     csrf.init_app(app)
 
     db = SQLA(app)
diff --git a/tests/core/test_sqlalchemy_config.py b/tests/core/test_sqlalchemy_config.py
index 6fa7ac9..99d5de8 100644
--- a/tests/core/test_sqlalchemy_config.py
+++ b/tests/core/test_sqlalchemy_config.py
@@ -19,6 +19,7 @@
 
 import unittest
 
+from airflow.exceptions import AirflowConfigException
 from sqlalchemy.pool import NullPool
 
 from airflow import settings
@@ -102,6 +103,6 @@ class TestSqlAlchemySettings(unittest.TestCase):
             ('core', 'sql_alchemy_connect_args'): 'does.not.exist',
             ('core', 'sql_alchemy_pool_enabled'): 'False'
         }
-        with self.assertRaises(ImportError):
+        with self.assertRaises(AirflowConfigException):
             with conf_vars(config):
                 settings.configure_orm()
diff --git a/tests/www/test_app.py b/tests/www/test_app.py
index 64255aa..56ec213 100644
--- a/tests/www/test_app.py
+++ b/tests/www/test_app.py
@@ -16,7 +16,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 import unittest
 
 from werkzeug.middleware.proxy_fix import ProxyFix
diff --git a/tests/www_rbac/test_app.py b/tests/www_rbac/test_app.py
index 71d6255..633176e 100644
--- a/tests/www_rbac/test_app.py
+++ b/tests/www_rbac/test_app.py
@@ -16,13 +16,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+import json
 import unittest
 
+import pytest
+import six
 from werkzeug.middleware.proxy_fix import ProxyFix
 
 from airflow.settings import Session
 from airflow.www_rbac import app as application
+from tests.compat import mock
 from tests.test_utils.config import conf_vars
 
 
@@ -56,3 +59,24 @@ class TestApp(unittest.TestCase):
         self.assertEqual(app.wsgi_app.x_host, 5)
         self.assertEqual(app.wsgi_app.x_port, 6)
         self.assertEqual(app.wsgi_app.x_prefix, 7)
+
+    @conf_vars({
+        ('core', 'sql_alchemy_pool_enabled'): 'True',
+        ('core', 'sql_alchemy_pool_size'): '3',
+        ('core', 'sql_alchemy_max_overflow'): '5',
+        ('core', 'sql_alchemy_pool_recycle'): '120',
+        ('core', 'sql_alchemy_pool_pre_ping'): 'True',
+    })
+    @mock.patch("airflow.www_rbac.app.app", None)
+    @pytest.mark.backend("mysql", "postgres")
+    def test_should_set_sqlalchemy_engine_options(self):
+        app = application.cached_appbuilder(testing=True).app
+        engine_params = {
+            'pool_size': 3,
+            'pool_recycle': 120,
+            'pool_pre_ping': True,
+            'max_overflow': 5
+        }
+        if six.PY2:
+            engine_params = json.dumps(engine_params)
+        self.assertEqual(app.config['SQLALCHEMY_ENGINE_OPTIONS'], engine_params)