You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2022/08/29 18:42:59 UTC

[superset] branch master updated: fix: improve get_db_engine_spec_for_backend (#21171)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 8772e2cdb3 fix: improve get_db_engine_spec_for_backend (#21171)
8772e2cdb3 is described below

commit 8772e2cdb3b5a500812e7df12c133f9c9f2e6bad
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Mon Aug 29 13:42:42 2022 -0500

    fix: improve get_db_engine_spec_for_backend (#21171)
    
    * fix: improve get_db_engine_spec_for_backend
    
    * Fix tests
    
    * Fix docs
    
    * fix lint
    
    * fix fallback
    
    * Fix engine validation
    
    * Fix test
---
 superset/databases/api.py                          |   4 +-
 superset/databases/commands/validate.py            |  27 +----
 superset/databases/schemas.py                      |  43 +++-----
 superset/db_engine_specs/__init__.py               |  48 ++++++---
 superset/db_engine_specs/base.py                   |  62 ++++++++++-
 superset/db_engine_specs/databricks.py             |  19 ++--
 superset/db_engine_specs/shillelagh.py             |   6 +-
 superset/models/core.py                            |  19 ++--
 tests/integration_tests/databases/api_tests.py     |   2 +-
 .../db_engine_specs/base_engine_spec_tests.py      |   4 +-
 .../db_engine_specs/postgres_tests.py              |   8 +-
 .../databases/schema_tests.py                      | 114 ++++++++++++++-------
 tests/unit_tests/models/core_test.py               |  77 +++++++++++++-
 13 files changed, 306 insertions(+), 127 deletions(-)

diff --git a/superset/databases/api.py b/superset/databases/api.py
index a6160b0d2f..4c617eb720 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -1083,8 +1083,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
                 "preferred": engine_spec.engine_name in preferred_databases,
             }
 
-            if hasattr(engine_spec, "default_driver"):
-                payload["default_driver"] = engine_spec.default_driver  # type: ignore
+            if engine_spec.default_driver:
+                payload["default_driver"] = engine_spec.default_driver
 
             # show configuration parameters for DBs that support it
             if (
diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py
index a9f1633a18..e9fe5eaf0c 100644
--- a/superset/databases/commands/validate.py
+++ b/superset/databases/commands/validate.py
@@ -29,8 +29,7 @@ from superset.databases.commands.exceptions import (
 )
 from superset.databases.dao import DatabaseDAO
 from superset.databases.utils import make_url_safe
-from superset.db_engine_specs import get_engine_specs
-from superset.db_engine_specs.base import BasicParametersMixin
+from superset.db_engine_specs import get_engine_spec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.extensions import event_logger
 from superset.models.core import Database
@@ -45,25 +44,13 @@ class ValidateDatabaseParametersCommand(BaseCommand):
 
     def run(self) -> None:
         engine = self._properties["engine"]
-        engine_specs = get_engine_specs()
+        driver = self._properties.get("driver")
 
         if engine in BYPASS_VALIDATION_ENGINES:
             # Skip engines that are only validated onCreate
             return
 
-        if engine not in engine_specs:
-            raise InvalidEngineError(
-                SupersetError(
-                    message=__(
-                        'Engine "%(engine)s" is not a valid engine.',
-                        engine=engine,
-                    ),
-                    error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
-                    level=ErrorLevel.ERROR,
-                    extra={"allowed": list(engine_specs), "provided": engine},
-                ),
-            )
-        engine_spec = engine_specs[engine]
+        engine_spec = get_engine_spec(engine, driver)
         if not hasattr(engine_spec, "parameters_schema"):
             raise InvalidEngineError(
                 SupersetError(
@@ -73,14 +60,6 @@ class ValidateDatabaseParametersCommand(BaseCommand):
                     ),
                     error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
                     level=ErrorLevel.ERROR,
-                    extra={
-                        "allowed": [
-                            name
-                            for name, engine_spec in engine_specs.items()
-                            if issubclass(engine_spec, BasicParametersMixin)
-                        ],
-                        "provided": engine,
-                    },
                 ),
             )
 
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index aa88822a85..b6a0ab6983 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -16,7 +16,7 @@
 # under the License.
 import inspect
 import json
-from typing import Any, Dict, Optional, Type
+from typing import Any, Dict
 
 from flask import current_app
 from flask_babel import lazy_gettext as _
@@ -28,7 +28,7 @@ from sqlalchemy import MetaData
 from superset import db
 from superset.databases.commands.exceptions import DatabaseInvalidError
 from superset.databases.utils import make_url_safe
-from superset.db_engine_specs import BaseEngineSpec, get_engine_specs
+from superset.db_engine_specs import get_engine_spec
 from superset.exceptions import CertificateException, SupersetSecurityException
 from superset.models.core import ConfigurationMethod, Database, PASSWORD_MASK
 from superset.security.analytics_db_safety import check_sqlalchemy_uri
@@ -150,7 +150,7 @@ def sqlalchemy_uri_validator(value: str) -> str:
             [
                 _(
                     "Invalid connection string, a valid string usually follows: "
-                    "driver://user:password@database-host/database-name"
+                    "backend+driver://user:password@database-host/database-name"
                 )
             ]
         ) from ex
@@ -231,6 +231,7 @@ class DatabaseParametersSchemaMixin:  # pylint: disable=too-few-public-methods
     """
 
     engine = fields.String(allow_none=True, description="SQLAlchemy engine to use")
+    driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
     parameters = fields.Dict(
         keys=fields.String(),
         values=fields.Raw(),
@@ -262,10 +263,20 @@ class DatabaseParametersSchemaMixin:  # pylint: disable=too-few-public-methods
             or parameters.pop("engine", None)
             or data.pop("backend", None)
         )
+        driver = data.pop("driver", None)
 
         configuration_method = data.get("configuration_method")
         if configuration_method == ConfigurationMethod.DYNAMIC_FORM:
-            engine_spec = get_engine_spec(engine)
+            if not engine:
+                raise ValidationError(
+                    [
+                        _(
+                            "An engine must be specified when passing "
+                            "individual parameters to a database."
+                        )
+                    ]
+                )
+            engine_spec = get_engine_spec(engine, driver)
 
             if not hasattr(engine_spec, "build_sqlalchemy_uri") or not hasattr(
                 engine_spec, "parameters_schema"
@@ -295,34 +306,12 @@ class DatabaseParametersSchemaMixin:  # pylint: disable=too-few-public-methods
         return data
 
 
-def get_engine_spec(engine: Optional[str]) -> Type[BaseEngineSpec]:
-    if not engine:
-        raise ValidationError(
-            [
-                _(
-                    "An engine must be specified when passing "
-                    "individual parameters to a database."
-                )
-            ]
-        )
-    engine_specs = get_engine_specs()
-    if engine not in engine_specs:
-        raise ValidationError(
-            [
-                _(
-                    'Engine "%(engine)s" is not a valid engine.',
-                    engine=engine,
-                )
-            ]
-        )
-    return engine_specs[engine]
-
-
 class DatabaseValidateParametersSchema(Schema):
     class Meta:  # pylint: disable=too-few-public-methods
         unknown = EXCLUDE
 
     engine = fields.String(required=True, description="SQLAlchemy engine to use")
+    driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
     parameters = fields.Dict(
         keys=fields.String(),
         values=fields.Raw(allow_none=True),
diff --git a/superset/db_engine_specs/__init__.py b/superset/db_engine_specs/__init__.py
index dac7001995..29e4877337 100644
--- a/superset/db_engine_specs/__init__.py
+++ b/superset/db_engine_specs/__init__.py
@@ -33,27 +33,34 @@ import pkgutil
 from collections import defaultdict
 from importlib import import_module
 from pathlib import Path
-from typing import Any, Dict, List, Set, Type
+from typing import Any, Dict, List, Optional, Set, Type
 
 import sqlalchemy.databases
 import sqlalchemy.dialects
 from pkg_resources import iter_entry_points
 from sqlalchemy.engine.default import DefaultDialect
+from sqlalchemy.engine.url import URL
 
 from superset.db_engine_specs.base import BaseEngineSpec
 
 logger = logging.getLogger(__name__)
 
 
-def is_engine_spec(attr: Any) -> bool:
+def is_engine_spec(obj: Any) -> bool:
+    """
+    Return true if a given object is a DB engine spec.
+    """
     return (
-        inspect.isclass(attr)
-        and issubclass(attr, BaseEngineSpec)
-        and attr != BaseEngineSpec
+        inspect.isclass(obj)
+        and issubclass(obj, BaseEngineSpec)
+        and obj != BaseEngineSpec
     )
 
 
 def load_engine_specs() -> List[Type[BaseEngineSpec]]:
+    """
+    Load all engine specs, native and 3rd party.
+    """
     engine_specs: List[Type[BaseEngineSpec]] = []
 
     # load standard engines
@@ -78,20 +85,31 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]:
     return engine_specs
 
 
-def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
+def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngineSpec]:
+    """
+    Return the DB engine spec associated with a given SQLAlchemy backend and driver.
+
+    Note that if a driver is not specified the function returns the first DB engine spec
+    that supports the backend. Also, if a driver is specified but no DB engine explicitly
+    supporting that driver exists then a backend-only match is done, in order to allow new
+    drivers to work with Superset even if they are not listed in the DB engine spec
+    drivers.
+    """
     engine_specs = load_engine_specs()
 
-    # build map from name/alias -> spec
-    engine_specs_map: Dict[str, Type[BaseEngineSpec]] = {}
-    for engine_spec in engine_specs:
-        names = [engine_spec.engine]
-        if engine_spec.engine_aliases:
-            names.extend(engine_spec.engine_aliases)
+    if driver is not None:
+        for engine_spec in engine_specs:
+            if engine_spec.supports_backend(backend, driver):
+                return engine_spec
 
-        for name in names:
-            engine_specs_map[name] = engine_spec
+    # check ignoring the driver, in order to support new drivers; this will return a
+    # random DB engine spec that supports the engine
+    for engine_spec in engine_specs:
+        if engine_spec.supports_backend(backend):
+            return engine_spec
 
-    return engine_specs_map
+    # default to the generic DB engine spec
+    return BaseEngineSpec
 
 
 # there's a mismatch between the dialect name reported by the driver in these
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 368770e261..1bf2f4a3f7 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -183,9 +183,15 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
                                        having to add the same aggregation in SELECT.
     """
 
+    engine_name: Optional[str] = None  # for user messages, overridden in child classes
+
+    # These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers;
+    # see the ``supports_url`` and ``supports_backend`` methods below.
     engine = "base"  # str as defined in sqlalchemy.engine.engine
     engine_aliases: Set[str] = set()
-    engine_name: Optional[str] = None  # for user messages, overridden in child classes
+    drivers: Dict[str, str] = {}
+    default_driver: Optional[str] = None
+
     _date_trunc_functions: Dict[str, str] = {}
     _time_grain_expressions: Dict[Optional[str], str] = {}
     column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
@@ -355,6 +361,58 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]
     ] = {}
 
+    @classmethod
+    def supports_url(cls, url: URL) -> bool:
+        """
+        Returns true if the DB engine spec supports a given SQLAlchemy URL.
+
+        As an example, if a given DB engine spec has:
+
+            class PostgresDBEngineSpec:
+                engine = "postgresql"
+                engine_aliases = {"postgres"}
+                drivers = {
+                    "psycopg2": "The default Postgres driver",
+                    "asyncpg": "An asynchronous Postgres driver",
+                }
+
+        It would be used for all the following SQLAlchemy URIs:
+
+            - postgres://user:password@host/db
+            - postgresql://user:password@host/db
+            - postgres+asyncpg://user:password@host/db
+            - postgres+psycopg2://user:password@host/db
+            - postgresql+asyncpg://user:password@host/db
+            - postgresql+psycopg2://user:password@host/db
+
+        Note that SQLAlchemy has a default driver even if one is not specified:
+
+            >>> from sqlalchemy.engine.url import make_url
+            >>> make_url('postgres://').get_driver_name()
+            'psycopg2'
+
+        """
+        backend = url.get_backend_name()
+        driver = url.get_driver_name()
+        return cls.supports_backend(backend, driver)
+
+    @classmethod
+    def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool:
+        """
+        Returns true if the DB engine spec supports a given SQLAlchemy backend/driver.
+        """
+        # check the backend first
+        if backend != cls.engine and backend not in cls.engine_aliases:
+            return False
+
+        # originally DB engine specs didn't declare any drivers and the check was made
+        # only on the engine; if that's the case, ignore the driver for backwards
+        # compatibility
+        if not cls.drivers or driver is None:
+            return True
+
+        return driver in cls.drivers
+
     @classmethod
     def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
         """
@@ -394,7 +452,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def get_text_clause(cls, clause: str) -> TextClause:
         """
-        SQLALchemy wrapper to ensure text clauses are escaped properly
+        SQLAlchemy wrapper to ensure text clauses are escaped properly
 
         :param clause: string clause with potentially unescaped characters
         :return: text clause with escaped characters
diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py
index 79718c93f6..90d90b9448 100644
--- a/superset/db_engine_specs/databricks.py
+++ b/superset/db_engine_specs/databricks.py
@@ -47,18 +47,23 @@ time_grain_expressions = {
 
 
 class DatabricksHiveEngineSpec(HiveEngineSpec):
-    engine = "databricks"
     engine_name = "Databricks Interactive Cluster"
-    driver = "pyhive"
+
+    engine = "databricks"
+    drivers = {"pyhive": "Hive driver for Interactive Cluster"}
+    default_driver = "pyhive"
+
     _show_functions_column = "function"
 
     _time_grain_expressions = time_grain_expressions
 
 
 class DatabricksODBCEngineSpec(BaseEngineSpec):
-    engine = "databricks"
     engine_name = "Databricks SQL Endpoint"
-    driver = "pyodbc"
+
+    engine = "databricks"
+    drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
+    default_driver = "pyodbc"
 
     _time_grain_expressions = time_grain_expressions
 
@@ -74,9 +79,11 @@ class DatabricksODBCEngineSpec(BaseEngineSpec):
 
 
 class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec):
-    engine = "databricks"
     engine_name = "Databricks Native Connector"
-    driver = "connector"
+
+    engine = "databricks"
+    drivers = {"connector": "Native all-purpose driver"}
+    default_driver = "connector"
 
     @staticmethod
     def get_extra_params(database: "Database") -> Dict[str, Any]:
diff --git a/superset/db_engine_specs/shillelagh.py b/superset/db_engine_specs/shillelagh.py
index c6e6f618c7..3730122448 100644
--- a/superset/db_engine_specs/shillelagh.py
+++ b/superset/db_engine_specs/shillelagh.py
@@ -20,7 +20,11 @@ from superset.db_engine_specs.sqlite import SqliteEngineSpec
 class ShillelaghEngineSpec(SqliteEngineSpec):
     """Engine for shillelagh"""
 
-    engine = "shillelagh"
     engine_name = "Shillelagh"
+    engine = "shillelagh"
+    drivers = {"apsw": "SQLite driver"}
+    default_driver = "apsw"
+    sqlalchemy_uri_placeholder = "shillelagh://"
+
     allows_joins = True
     allows_subqueries = True
diff --git a/superset/models/core.py b/superset/models/core.py
index b5a4aa6537..ec7ec79321 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -46,7 +46,7 @@ from sqlalchemy import (
 from sqlalchemy.engine import Connection, Dialect, Engine
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
-from sqlalchemy.exc import ArgumentError
+from sqlalchemy.exc import ArgumentError, NoSuchModuleError
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import relationship
 from sqlalchemy.pool import NullPool
@@ -635,15 +635,20 @@ class Database(
 
     @property
     def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
-        return self.get_db_engine_spec_for_backend(self.backend)
+        url = make_url_safe(self.sqlalchemy_uri_decrypted)
+        return self.get_db_engine_spec(url)
 
     @classmethod
     @memoized
-    def get_db_engine_spec_for_backend(
-        cls, backend: str
-    ) -> Type[db_engine_specs.BaseEngineSpec]:
-        engines = db_engine_specs.get_engine_specs()
-        return engines.get(backend, db_engine_specs.BaseEngineSpec)
+    def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]:
+        backend = url.get_backend_name()
+        try:
+            driver = url.get_driver_name()
+        except NoSuchModuleError:
+            # can't load the driver; fall back to matching on the backend only
+            driver = None
+
+        return db_engine_specs.get_engine_spec(backend, driver)
 
     def grains(self) -> Tuple[TimeGrain, ...]:
         """Defines time granularity database-specific expressions.
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index 8ff12b2406..b53418fb16 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -1425,7 +1425,7 @@ class TestDatabaseApi(SupersetTestCase):
         expected_response = {
             "errors": [
                 {
-                    "message": "Could not load database driver: AzureSynapseSpec",
+                    "message": "Could not load database driver: MssqlEngineSpec",
                     "error_type": "GENERIC_COMMAND_ERROR",
                     "level": "warning",
                     "extra": {
diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
index 07f9bfcf31..f998444f31 100644
--- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
+++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
@@ -20,7 +20,7 @@ from unittest import mock
 import pytest
 
 from superset.connectors.sqla.models import TableColumn
-from superset.db_engine_specs import get_engine_specs
+from superset.db_engine_specs import load_engine_specs
 from superset.db_engine_specs.base import (
     BaseEngineSpec,
     BasicParametersMixin,
@@ -195,7 +195,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
     def test_engine_time_grain_validity(self):
         time_grains = set(builtin_time_grains.keys())
         # loop over all subclasses of BaseEngineSpec
-        for engine in get_engine_specs().values():
+        for engine in load_engine_specs():
             if engine is not BaseEngineSpec:
                 # make sure time grain functions have been defined
                 self.assertGreater(len(engine.get_time_grain_expressions()), 0)
diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py
index e6eb4fc1d1..79a307a488 100644
--- a/tests/integration_tests/db_engine_specs/postgres_tests.py
+++ b/tests/integration_tests/db_engine_specs/postgres_tests.py
@@ -20,7 +20,7 @@ from unittest import mock
 from sqlalchemy import column, literal_column
 from sqlalchemy.dialects import postgresql
 
-from superset.db_engine_specs import get_engine_specs
+from superset.db_engine_specs import load_engine_specs
 from superset.db_engine_specs.postgres import PostgresEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.models.sql_lab import Query
@@ -137,7 +137,11 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
         """
         DB Eng Specs (postgres): Test "postgres" in engine spec
         """
-        self.assertIn("postgres", get_engine_specs())
+        backends = set()
+        for engine in load_engine_specs():
+            backends.add(engine.engine)
+            backends.update(engine.engine_aliases)
+        assert "postgres" in backends
 
     def test_extras_without_ssl(self):
         db = mock.Mock()
diff --git a/tests/integration_tests/databases/schema_tests.py b/tests/unit_tests/databases/schema_tests.py
similarity index 57%
rename from tests/integration_tests/databases/schema_tests.py
rename to tests/unit_tests/databases/schema_tests.py
index 1f8ca067f6..58a1f6389d 100644
--- a/tests/integration_tests/databases/schema_tests.py
+++ b/tests/unit_tests/databases/schema_tests.py
@@ -15,31 +15,59 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from unittest import mock
+# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, redefined-outer-name
 
+from typing import TYPE_CHECKING
+
+import pytest
 from marshmallow import fields, Schema, ValidationError
+from pytest_mock import MockFixture
+
+if TYPE_CHECKING:
+    from superset.databases.schemas import DatabaseParametersSchemaMixin
+    from superset.db_engine_specs.base import BasicParametersMixin
 
-from superset.databases.schemas import DatabaseParametersSchemaMixin
-from superset.db_engine_specs.base import BasicParametersMixin
-from superset.models.core import ConfigurationMethod
 
+# pylint: disable=too-few-public-methods
+class InvalidEngine:
+    """
+    An invalid DB engine spec.
+    """
 
-class DummySchema(Schema, DatabaseParametersSchemaMixin):
-    sqlalchemy_uri = fields.String()
 
+@pytest.fixture
+def dummy_schema() -> "DatabaseParametersSchemaMixin":
+    """
+    Fixture providing a dummy schema.
+    """
+    from superset.databases.schemas import DatabaseParametersSchemaMixin
 
-class DummyEngine(BasicParametersMixin):
-    engine = "dummy"
-    default_driver = "dummy"
+    class DummySchema(Schema, DatabaseParametersSchemaMixin):
+        sqlalchemy_uri = fields.String()
 
+    return DummySchema()
+
+
+@pytest.fixture
+def dummy_engine(mocker: MockFixture) -> None:
+    """
+    Fixture providing a dummy DB engine spec.
+    """
+    from superset.db_engine_specs.base import BasicParametersMixin
+
+    class DummyEngine(BasicParametersMixin):
+        engine = "dummy"
+        default_driver = "dummy"
+
+    mocker.patch("superset.databases.schemas.get_engine_spec", return_value=DummyEngine)
 
-class InvalidEngine:
-    pass
 
+def test_database_parameters_schema_mixin(
+    dummy_engine: None,
+    dummy_schema: "Schema",
+) -> None:
+    from superset.models.core import ConfigurationMethod
 
-@mock.patch("superset.databases.schemas.get_engine_specs")
-def test_database_parameters_schema_mixin(get_engine_specs):
-    get_engine_specs.return_value = {"dummy_engine": DummyEngine}
     payload = {
         "engine": "dummy_engine",
         "configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@@ -51,15 +79,18 @@ def test_database_parameters_schema_mixin(get_engine_specs):
             "database": "dbname",
         },
     }
-    schema = DummySchema()
-    result = schema.load(payload)
+    result = dummy_schema.load(payload)
     assert result == {
         "configuration_method": ConfigurationMethod.DYNAMIC_FORM,
         "sqlalchemy_uri": "dummy+dummy://username:password@localhost:12345/dbname",
     }
 
 
-def test_database_parameters_schema_mixin_no_engine():
+def test_database_parameters_schema_mixin_no_engine(
+    dummy_schema: "Schema",
+) -> None:
+    from superset.models.core import ConfigurationMethod
+
     payload = {
         "configuration_method": ConfigurationMethod.DYNAMIC_FORM,
         "parameters": {
@@ -67,23 +98,28 @@ def test_database_parameters_schema_mixin_no_engine():
             "password": "password",
             "host": "localhost",
             "port": 12345,
-            "dbname": "dbname",
+            "database": "dbname",
         },
     }
-    schema = DummySchema()
     try:
-        schema.load(payload)
+        dummy_schema.load(payload)
     except ValidationError as err:
         assert err.messages == {
             "_schema": [
-                "An engine must be specified when passing individual parameters to a database."
+                (
+                    "An engine must be specified when passing individual parameters to "
+                    "a database."
+                ),
             ]
         }
 
 
-@mock.patch("superset.databases.schemas.get_engine_specs")
-def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
-    get_engine_specs.return_value = {}
+def test_database_parameters_schema_mixin_invalid_engine(
+    dummy_engine: None,
+    dummy_schema: "Schema",
+) -> None:
+    from superset.models.core import ConfigurationMethod
+
     payload = {
         "engine": "dummy_engine",
         "configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@@ -92,21 +128,24 @@ def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
             "password": "password",
             "host": "localhost",
             "port": 12345,
-            "dbname": "dbname",
+            "database": "dbname",
         },
     }
-    schema = DummySchema()
     try:
-        schema.load(payload)
+        dummy_schema.load(payload)
     except ValidationError as err:
+        print(err.messages)
         assert err.messages == {
             "_schema": ['Engine "dummy_engine" is not a valid engine.']
         }
 
 
-@mock.patch("superset.databases.schemas.get_engine_specs")
-def test_database_parameters_schema_no_mixin(get_engine_specs):
-    get_engine_specs.return_value = {"invalid_engine": InvalidEngine}
+def test_database_parameters_schema_no_mixin(
+    dummy_engine: None,
+    dummy_schema: "Schema",
+) -> None:
+    from superset.models.core import ConfigurationMethod
+
     payload = {
         "engine": "invalid_engine",
         "configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@@ -118,9 +157,8 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
             "database": "dbname",
         },
     }
-    schema = DummySchema()
     try:
-        schema.load(payload)
+        dummy_schema.load(payload)
     except ValidationError as err:
         assert err.messages == {
             "_schema": [
@@ -132,9 +170,12 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
         }
 
 
-@mock.patch("superset.databases.schemas.get_engine_specs")
-def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
-    get_engine_specs.return_value = {"dummy_engine": DummyEngine}
+def test_database_parameters_schema_mixin_invalid_type(
+    dummy_engine: None,
+    dummy_schema: "Schema",
+) -> None:
+    from superset.models.core import ConfigurationMethod
+
     payload = {
         "engine": "dummy_engine",
         "configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@@ -146,8 +187,7 @@ def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
             "database": "dbname",
         },
     }
-    schema = DummySchema()
     try:
-        schema.load(payload)
+        dummy_schema.load(payload)
     except ValidationError as err:
         assert err.messages == {"port": ["Not a valid integer."]}
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index 3338ddcb61..5eb60dc6f9 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -59,7 +59,7 @@ def test_get_metrics(mocker: MockFixture) -> None:
                 },
             ]
 
-    database.get_db_engine_spec_for_backend = mocker.MagicMock(  # type: ignore
+    database.get_db_engine_spec = mocker.MagicMock(  # type: ignore
         return_value=CustomSqliteEngineSpec
     )
     assert database.get_metrics("table") == [
@@ -70,3 +70,78 @@ def test_get_metrics(mocker: MockFixture) -> None:
             "verbose_name": "COUNT(DISTINCT user_id)",
         },
     ]
+
+
+def test_get_db_engine_spec(mocker: MockFixture) -> None:
+    """
+    Tests for ``get_db_engine_spec``.
+    """
+    from superset.db_engine_specs import BaseEngineSpec
+    from superset.models.core import Database
+
+    # pylint: disable=abstract-method
+    class PostgresDBEngineSpec(BaseEngineSpec):
+        """
+        A DB engine spec with drivers and a default driver.
+        """
+
+        engine = "postgresql"
+        engine_aliases = {"postgres"}
+        drivers = {
+            "psycopg2": "The default Postgres driver",
+            "asyncpg": "An async Postgres driver",
+        }
+        default_driver = "psycopg2"
+
+    # pylint: disable=abstract-method
+    class OldDBEngineSpec(BaseEngineSpec):
+        """
+        An old DB engine spec with neither drivers nor a default driver.
+        """
+
+        engine = "mysql"
+
+    load_engine_specs = mocker.patch("superset.db_engine_specs.load_engine_specs")
+    load_engine_specs.return_value = [
+        PostgresDBEngineSpec,
+        OldDBEngineSpec,
+    ]
+
+    assert (
+        Database(database_name="db", sqlalchemy_uri="postgresql://").db_engine_spec
+        == PostgresDBEngineSpec
+    )
+    assert (
+        Database(
+            database_name="db", sqlalchemy_uri="postgresql+psycopg2://"
+        ).db_engine_spec
+        == PostgresDBEngineSpec
+    )
+    assert (
+        Database(
+            database_name="db", sqlalchemy_uri="postgresql+asyncpg://"
+        ).db_engine_spec
+        == PostgresDBEngineSpec
+    )
+    assert (
+        Database(
+            database_name="db", sqlalchemy_uri="postgresql+fancynewdriver://"
+        ).db_engine_spec
+        == PostgresDBEngineSpec
+    )
+    assert (
+        Database(database_name="db", sqlalchemy_uri="mysql://").db_engine_spec
+        == OldDBEngineSpec
+    )
+    assert (
+        Database(
+            database_name="db", sqlalchemy_uri="mysql+mysqlconnector://"
+        ).db_engine_spec
+        == OldDBEngineSpec
+    )
+    assert (
+        Database(
+            database_name="db", sqlalchemy_uri="mysql+fancynewdriver://"
+        ).db_engine_spec
+        == OldDBEngineSpec
+    )