You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2022/08/29 18:42:59 UTC
[superset] branch master updated: fix: improve get_db_engine_spec_for_backend (#21171)
This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 8772e2cdb3 fix: improve get_db_engine_spec_for_backend (#21171)
8772e2cdb3 is described below
commit 8772e2cdb3b5a500812e7df12c133f9c9f2e6bad
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Mon Aug 29 13:42:42 2022 -0500
fix: improve get_db_engine_spec_for_backend (#21171)
* fix: improve get_db_engine_spec_for_backend
* Fix tests
* Fix docs
* fix lint
* fix fallback
* Fix engine validation
* Fix test
---
superset/databases/api.py | 4 +-
superset/databases/commands/validate.py | 27 +----
superset/databases/schemas.py | 43 +++-----
superset/db_engine_specs/__init__.py | 48 ++++++---
superset/db_engine_specs/base.py | 62 ++++++++++-
superset/db_engine_specs/databricks.py | 19 ++--
superset/db_engine_specs/shillelagh.py | 6 +-
superset/models/core.py | 19 ++--
tests/integration_tests/databases/api_tests.py | 2 +-
.../db_engine_specs/base_engine_spec_tests.py | 4 +-
.../db_engine_specs/postgres_tests.py | 8 +-
.../databases/schema_tests.py | 114 ++++++++++++++-------
tests/unit_tests/models/core_test.py | 77 +++++++++++++-
13 files changed, 306 insertions(+), 127 deletions(-)
diff --git a/superset/databases/api.py b/superset/databases/api.py
index a6160b0d2f..4c617eb720 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -1083,8 +1083,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"preferred": engine_spec.engine_name in preferred_databases,
}
- if hasattr(engine_spec, "default_driver"):
- payload["default_driver"] = engine_spec.default_driver # type: ignore
+ if engine_spec.default_driver:
+ payload["default_driver"] = engine_spec.default_driver
# show configuration parameters for DBs that support it
if (
diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py
index a9f1633a18..e9fe5eaf0c 100644
--- a/superset/databases/commands/validate.py
+++ b/superset/databases/commands/validate.py
@@ -29,8 +29,7 @@ from superset.databases.commands.exceptions import (
)
from superset.databases.dao import DatabaseDAO
from superset.databases.utils import make_url_safe
-from superset.db_engine_specs import get_engine_specs
-from superset.db_engine_specs.base import BasicParametersMixin
+from superset.db_engine_specs import get_engine_spec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions import event_logger
from superset.models.core import Database
@@ -45,25 +44,13 @@ class ValidateDatabaseParametersCommand(BaseCommand):
def run(self) -> None:
engine = self._properties["engine"]
- engine_specs = get_engine_specs()
+ driver = self._properties.get("driver")
if engine in BYPASS_VALIDATION_ENGINES:
# Skip engines that are only validated onCreate
return
- if engine not in engine_specs:
- raise InvalidEngineError(
- SupersetError(
- message=__(
- 'Engine "%(engine)s" is not a valid engine.',
- engine=engine,
- ),
- error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
- level=ErrorLevel.ERROR,
- extra={"allowed": list(engine_specs), "provided": engine},
- ),
- )
- engine_spec = engine_specs[engine]
+ engine_spec = get_engine_spec(engine, driver)
if not hasattr(engine_spec, "parameters_schema"):
raise InvalidEngineError(
SupersetError(
@@ -73,14 +60,6 @@ class ValidateDatabaseParametersCommand(BaseCommand):
),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
- extra={
- "allowed": [
- name
- for name, engine_spec in engine_specs.items()
- if issubclass(engine_spec, BasicParametersMixin)
- ],
- "provided": engine,
- },
),
)
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index aa88822a85..b6a0ab6983 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -16,7 +16,7 @@
# under the License.
import inspect
import json
-from typing import Any, Dict, Optional, Type
+from typing import Any, Dict
from flask import current_app
from flask_babel import lazy_gettext as _
@@ -28,7 +28,7 @@ from sqlalchemy import MetaData
from superset import db
from superset.databases.commands.exceptions import DatabaseInvalidError
from superset.databases.utils import make_url_safe
-from superset.db_engine_specs import BaseEngineSpec, get_engine_specs
+from superset.db_engine_specs import get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException
from superset.models.core import ConfigurationMethod, Database, PASSWORD_MASK
from superset.security.analytics_db_safety import check_sqlalchemy_uri
@@ -150,7 +150,7 @@ def sqlalchemy_uri_validator(value: str) -> str:
[
_(
"Invalid connection string, a valid string usually follows: "
- "driver://user:password@database-host/database-name"
+ "backend+driver://user:password@database-host/database-name"
)
]
) from ex
@@ -231,6 +231,7 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
"""
engine = fields.String(allow_none=True, description="SQLAlchemy engine to use")
+ driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
parameters = fields.Dict(
keys=fields.String(),
values=fields.Raw(),
@@ -262,10 +263,20 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
or parameters.pop("engine", None)
or data.pop("backend", None)
)
+ driver = data.pop("driver", None)
configuration_method = data.get("configuration_method")
if configuration_method == ConfigurationMethod.DYNAMIC_FORM:
- engine_spec = get_engine_spec(engine)
+ if not engine:
+ raise ValidationError(
+ [
+ _(
+ "An engine must be specified when passing "
+ "individual parameters to a database."
+ )
+ ]
+ )
+ engine_spec = get_engine_spec(engine, driver)
if not hasattr(engine_spec, "build_sqlalchemy_uri") or not hasattr(
engine_spec, "parameters_schema"
@@ -295,34 +306,12 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
return data
-def get_engine_spec(engine: Optional[str]) -> Type[BaseEngineSpec]:
- if not engine:
- raise ValidationError(
- [
- _(
- "An engine must be specified when passing "
- "individual parameters to a database."
- )
- ]
- )
- engine_specs = get_engine_specs()
- if engine not in engine_specs:
- raise ValidationError(
- [
- _(
- 'Engine "%(engine)s" is not a valid engine.',
- engine=engine,
- )
- ]
- )
- return engine_specs[engine]
-
-
class DatabaseValidateParametersSchema(Schema):
class Meta: # pylint: disable=too-few-public-methods
unknown = EXCLUDE
engine = fields.String(required=True, description="SQLAlchemy engine to use")
+ driver = fields.String(allow_none=True, description="SQLAlchemy driver to use")
parameters = fields.Dict(
keys=fields.String(),
values=fields.Raw(allow_none=True),
diff --git a/superset/db_engine_specs/__init__.py b/superset/db_engine_specs/__init__.py
index dac7001995..29e4877337 100644
--- a/superset/db_engine_specs/__init__.py
+++ b/superset/db_engine_specs/__init__.py
@@ -33,27 +33,34 @@ import pkgutil
from collections import defaultdict
from importlib import import_module
from pathlib import Path
-from typing import Any, Dict, List, Set, Type
+from typing import Any, Dict, List, Optional, Set, Type
import sqlalchemy.databases
import sqlalchemy.dialects
from pkg_resources import iter_entry_points
from sqlalchemy.engine.default import DefaultDialect
+from sqlalchemy.engine.url import URL
from superset.db_engine_specs.base import BaseEngineSpec
logger = logging.getLogger(__name__)
-def is_engine_spec(attr: Any) -> bool:
+def is_engine_spec(obj: Any) -> bool:
+ """
+ Return true if a given object is a DB engine spec.
+ """
return (
- inspect.isclass(attr)
- and issubclass(attr, BaseEngineSpec)
- and attr != BaseEngineSpec
+ inspect.isclass(obj)
+ and issubclass(obj, BaseEngineSpec)
+ and obj != BaseEngineSpec
)
def load_engine_specs() -> List[Type[BaseEngineSpec]]:
+ """
+ Load all engine specs, native and 3rd party.
+ """
engine_specs: List[Type[BaseEngineSpec]] = []
# load standard engines
@@ -78,20 +85,31 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]:
return engine_specs
-def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
+def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngineSpec]:
+ """
+ Return the DB engine spec associated with a given SQLAlchemy URL.
+
+ Note that if a driver is not specified the function returns the first DB engine spec
+ that supports the backend. Also, if a driver is specified but no DB engine explicitly
+ supporting that driver exists then a backend-only match is done, in order to allow new
+ drivers to work with Superset even if they are not listed in the DB engine spec
+ drivers.
+ """
engine_specs = load_engine_specs()
- # build map from name/alias -> spec
- engine_specs_map: Dict[str, Type[BaseEngineSpec]] = {}
- for engine_spec in engine_specs:
- names = [engine_spec.engine]
- if engine_spec.engine_aliases:
- names.extend(engine_spec.engine_aliases)
+ if driver is not None:
+ for engine_spec in engine_specs:
+ if engine_spec.supports_backend(backend, driver):
+ return engine_spec
- for name in names:
- engine_specs_map[name] = engine_spec
+ # check ignoring the driver, in order to support new drivers; this will return a
+ # random DB engine spec that supports the engine
+ for engine_spec in engine_specs:
+ if engine_spec.supports_backend(backend):
+ return engine_spec
- return engine_specs_map
+ # default to the generic DB engine spec
+ return BaseEngineSpec
# there's a mismatch between the dialect name reported by the driver in these
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 368770e261..1bf2f4a3f7 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -183,9 +183,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
having to add the same aggregation in SELECT.
"""
+ engine_name: Optional[str] = None # for user messages, overridden in child classes
+
+ # These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers;
+ # see the ``supports_url`` and ``supports_backend`` methods below.
engine = "base" # str as defined in sqlalchemy.engine.engine
engine_aliases: Set[str] = set()
- engine_name: Optional[str] = None # for user messages, overridden in child classes
+ drivers: Dict[str, str] = {}
+ default_driver: Optional[str] = None
+
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
@@ -355,6 +361,58 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]
] = {}
+ @classmethod
+ def supports_url(cls, url: URL) -> bool:
+ """
+ Returns true if the DB engine spec supports a given SQLAlchemy URL.
+
+ As an example, if a given DB engine spec has:
+
+ class PostgresDBEngineSpec:
+ engine = "postgresql"
+ engine_aliases = "postgres"
+ drivers = {
+ "psycopg2": "The default Postgres driver",
+ "asyncpg": "An asynchronous Postgres driver",
+ }
+
+ It would be used for all the following SQLAlchemy URIs:
+
+ - postgres://user:password@host/db
+ - postgresql://user:password@host/db
+ - postgres+asyncpg://user:password@host/db
+ - postgres+psycopg2://user:password@host/db
+ - postgresql+asyncpg://user:password@host/db
+ - postgresql+psycopg2://user:password@host/db
+
+ Note that SQLAlchemy has a default driver even if one is not specified:
+
+ >>> from sqlalchemy.engine.url import make_url
+ >>> make_url('postgres://').get_driver_name()
+ 'psycopg2'
+
+ """
+ backend = url.get_backend_name()
+ driver = url.get_driver_name()
+ return cls.supports_backend(backend, driver)
+
+ @classmethod
+ def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool:
+ """
+ Returns true if the DB engine spec supports a given SQLAlchemy backend/driver.
+ """
+ # check the backend first
+ if backend != cls.engine and backend not in cls.engine_aliases:
+ return False
+
+ # originally DB engine specs didn't declare any drivers and the check was made
+ # only on the engine; if that's the case, ignore the driver for backwards
+ # compatibility
+ if not cls.drivers or driver is None:
+ return True
+
+ return driver in cls.drivers
+
@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
"""
@@ -394,7 +452,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_text_clause(cls, clause: str) -> TextClause:
"""
- SQLALchemy wrapper to ensure text clauses are escaped properly
+ SQLAlchemy wrapper to ensure text clauses are escaped properly
:param clause: string clause with potentially unescaped characters
:return: text clause with escaped characters
diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py
index 79718c93f6..90d90b9448 100644
--- a/superset/db_engine_specs/databricks.py
+++ b/superset/db_engine_specs/databricks.py
@@ -47,18 +47,23 @@ time_grain_expressions = {
class DatabricksHiveEngineSpec(HiveEngineSpec):
- engine = "databricks"
engine_name = "Databricks Interactive Cluster"
- driver = "pyhive"
+
+ engine = "databricks"
+ drivers = {"pyhive": "Hive driver for Interactive Cluster"}
+ default_driver = "pyhive"
+
_show_functions_column = "function"
_time_grain_expressions = time_grain_expressions
class DatabricksODBCEngineSpec(BaseEngineSpec):
- engine = "databricks"
engine_name = "Databricks SQL Endpoint"
- driver = "pyodbc"
+
+ engine = "databricks"
+ drivers = {"pyodbc": "ODBC driver for SQL endpoint"}
+ default_driver = "pyodbc"
_time_grain_expressions = time_grain_expressions
@@ -74,9 +79,11 @@ class DatabricksODBCEngineSpec(BaseEngineSpec):
class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec):
- engine = "databricks"
engine_name = "Databricks Native Connector"
- driver = "connector"
+
+ engine = "databricks"
+ drivers = {"connector": "Native all-purpose driver"}
+ default_driver = "connector"
@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
diff --git a/superset/db_engine_specs/shillelagh.py b/superset/db_engine_specs/shillelagh.py
index c6e6f618c7..3730122448 100644
--- a/superset/db_engine_specs/shillelagh.py
+++ b/superset/db_engine_specs/shillelagh.py
@@ -20,7 +20,11 @@ from superset.db_engine_specs.sqlite import SqliteEngineSpec
class ShillelaghEngineSpec(SqliteEngineSpec):
"""Engine for shillelagh"""
- engine = "shillelagh"
engine_name = "Shillelagh"
+ engine = "shillelagh"
+ drivers = {"apsw": "SQLite driver"}
+ default_driver = "apsw"
+ sqlalchemy_uri_placeholder = "shillelagh://"
+
allows_joins = True
allows_subqueries = True
diff --git a/superset/models/core.py b/superset/models/core.py
index b5a4aa6537..ec7ec79321 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -46,7 +46,7 @@ from sqlalchemy import (
from sqlalchemy.engine import Connection, Dialect, Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
-from sqlalchemy.exc import ArgumentError
+from sqlalchemy.exc import ArgumentError, NoSuchModuleError
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship
from sqlalchemy.pool import NullPool
@@ -635,15 +635,20 @@ class Database(
@property
def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
- return self.get_db_engine_spec_for_backend(self.backend)
+ url = make_url_safe(self.sqlalchemy_uri_decrypted)
+ return self.get_db_engine_spec(url)
@classmethod
@memoized
- def get_db_engine_spec_for_backend(
- cls, backend: str
- ) -> Type[db_engine_specs.BaseEngineSpec]:
- engines = db_engine_specs.get_engine_specs()
- return engines.get(backend, db_engine_specs.BaseEngineSpec)
+ def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]:
+ backend = url.get_backend_name()
+ try:
+ driver = url.get_driver_name()
+ except NoSuchModuleError:
+ # can't load the driver, fallback for backwards compatibility
+ driver = None
+
+ return db_engine_specs.get_engine_spec(backend, driver)
def grains(self) -> Tuple[TimeGrain, ...]:
"""Defines time granularity database-specific expressions.
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index 8ff12b2406..b53418fb16 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -1425,7 +1425,7 @@ class TestDatabaseApi(SupersetTestCase):
expected_response = {
"errors": [
{
- "message": "Could not load database driver: AzureSynapseSpec",
+ "message": "Could not load database driver: MssqlEngineSpec",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
index 07f9bfcf31..f998444f31 100644
--- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
+++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py
@@ -20,7 +20,7 @@ from unittest import mock
import pytest
from superset.connectors.sqla.models import TableColumn
-from superset.db_engine_specs import get_engine_specs
+from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
@@ -195,7 +195,7 @@ class TestDbEngineSpecs(TestDbEngineSpec):
def test_engine_time_grain_validity(self):
time_grains = set(builtin_time_grains.keys())
# loop over all subclasses of BaseEngineSpec
- for engine in get_engine_specs().values():
+ for engine in load_engine_specs():
if engine is not BaseEngineSpec:
# make sure time grain functions have been defined
self.assertGreater(len(engine.get_time_grain_expressions()), 0)
diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py
index e6eb4fc1d1..79a307a488 100644
--- a/tests/integration_tests/db_engine_specs/postgres_tests.py
+++ b/tests/integration_tests/db_engine_specs/postgres_tests.py
@@ -20,7 +20,7 @@ from unittest import mock
from sqlalchemy import column, literal_column
from sqlalchemy.dialects import postgresql
-from superset.db_engine_specs import get_engine_specs
+from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
@@ -137,7 +137,11 @@ class TestPostgresDbEngineSpec(TestDbEngineSpec):
"""
DB Eng Specs (postgres): Test "postgres" in engine spec
"""
- self.assertIn("postgres", get_engine_specs())
+ backends = set()
+ for engine in load_engine_specs():
+ backends.add(engine.engine)
+ backends.update(engine.engine_aliases)
+ assert "postgres" in backends
def test_extras_without_ssl(self):
db = mock.Mock()
diff --git a/tests/integration_tests/databases/schema_tests.py b/tests/unit_tests/databases/schema_tests.py
similarity index 57%
rename from tests/integration_tests/databases/schema_tests.py
rename to tests/unit_tests/databases/schema_tests.py
index 1f8ca067f6..58a1f6389d 100644
--- a/tests/integration_tests/databases/schema_tests.py
+++ b/tests/unit_tests/databases/schema_tests.py
@@ -15,31 +15,59 @@
# specific language governing permissions and limitations
# under the License.
-from unittest import mock
+# pylint: disable=import-outside-toplevel, invalid-name, unused-argument, redefined-outer-name
+from typing import TYPE_CHECKING
+
+import pytest
from marshmallow import fields, Schema, ValidationError
+from pytest_mock import MockFixture
+
+if TYPE_CHECKING:
+ from superset.databases.schemas import DatabaseParametersSchemaMixin
+ from superset.db_engine_specs.base import BasicParametersMixin
-from superset.databases.schemas import DatabaseParametersSchemaMixin
-from superset.db_engine_specs.base import BasicParametersMixin
-from superset.models.core import ConfigurationMethod
+# pylint: disable=too-few-public-methods
+class InvalidEngine:
+ """
+ An invalid DB engine spec.
+ """
-class DummySchema(Schema, DatabaseParametersSchemaMixin):
- sqlalchemy_uri = fields.String()
+@pytest.fixture
+def dummy_schema() -> "DatabaseParametersSchemaMixin":
+ """
+ Fixture providing a dummy schema.
+ """
+ from superset.databases.schemas import DatabaseParametersSchemaMixin
-class DummyEngine(BasicParametersMixin):
- engine = "dummy"
- default_driver = "dummy"
+ class DummySchema(Schema, DatabaseParametersSchemaMixin):
+ sqlalchemy_uri = fields.String()
+ return DummySchema()
+
+
+@pytest.fixture
+def dummy_engine(mocker: MockFixture) -> None:
+ """
+ Fixture proving a dummy DB engine spec.
+ """
+ from superset.db_engine_specs.base import BasicParametersMixin
+
+ class DummyEngine(BasicParametersMixin):
+ engine = "dummy"
+ default_driver = "dummy"
+
+ mocker.patch("superset.databases.schemas.get_engine_spec", return_value=DummyEngine)
-class InvalidEngine:
- pass
+def test_database_parameters_schema_mixin(
+ dummy_engine: None,
+ dummy_schema: "Schema",
+) -> None:
+ from superset.models.core import ConfigurationMethod
-@mock.patch("superset.databases.schemas.get_engine_specs")
-def test_database_parameters_schema_mixin(get_engine_specs):
- get_engine_specs.return_value = {"dummy_engine": DummyEngine}
payload = {
"engine": "dummy_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@@ -51,15 +79,18 @@ def test_database_parameters_schema_mixin(get_engine_specs):
"database": "dbname",
},
}
- schema = DummySchema()
- result = schema.load(payload)
+ result = dummy_schema.load(payload)
assert result == {
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
"sqlalchemy_uri": "dummy+dummy://username:password@localhost:12345/dbname",
}
-def test_database_parameters_schema_mixin_no_engine():
+def test_database_parameters_schema_mixin_no_engine(
+ dummy_schema: "Schema",
+) -> None:
+ from superset.models.core import ConfigurationMethod
+
payload = {
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
"parameters": {
@@ -67,23 +98,28 @@ def test_database_parameters_schema_mixin_no_engine():
"password": "password",
"host": "localhost",
"port": 12345,
- "dbname": "dbname",
+ "database": "dbname",
},
}
- schema = DummySchema()
try:
- schema.load(payload)
+ dummy_schema.load(payload)
except ValidationError as err:
assert err.messages == {
"_schema": [
- "An engine must be specified when passing individual parameters to a database."
+ (
+ "An engine must be specified when passing individual parameters to "
+ "a database."
+ ),
]
}
-@mock.patch("superset.databases.schemas.get_engine_specs")
-def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
- get_engine_specs.return_value = {}
+def test_database_parameters_schema_mixin_invalid_engine(
+ dummy_engine: None,
+ dummy_schema: "Schema",
+) -> None:
+ from superset.models.core import ConfigurationMethod
+
payload = {
"engine": "dummy_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@@ -92,21 +128,24 @@ def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
"password": "password",
"host": "localhost",
"port": 12345,
- "dbname": "dbname",
+ "database": "dbname",
},
}
- schema = DummySchema()
try:
- schema.load(payload)
+ dummy_schema.load(payload)
except ValidationError as err:
+ print(err.messages)
assert err.messages == {
"_schema": ['Engine "dummy_engine" is not a valid engine.']
}
-@mock.patch("superset.databases.schemas.get_engine_specs")
-def test_database_parameters_schema_no_mixin(get_engine_specs):
- get_engine_specs.return_value = {"invalid_engine": InvalidEngine}
+def test_database_parameters_schema_no_mixin(
+ dummy_engine: None,
+ dummy_schema: "Schema",
+) -> None:
+ from superset.models.core import ConfigurationMethod
+
payload = {
"engine": "invalid_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@@ -118,9 +157,8 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
"database": "dbname",
},
}
- schema = DummySchema()
try:
- schema.load(payload)
+ dummy_schema.load(payload)
except ValidationError as err:
assert err.messages == {
"_schema": [
@@ -132,9 +170,12 @@ def test_database_parameters_schema_no_mixin(get_engine_specs):
}
-@mock.patch("superset.databases.schemas.get_engine_specs")
-def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
- get_engine_specs.return_value = {"dummy_engine": DummyEngine}
+def test_database_parameters_schema_mixin_invalid_type(
+ dummy_engine: None,
+ dummy_schema: "Schema",
+) -> None:
+ from superset.models.core import ConfigurationMethod
+
payload = {
"engine": "dummy_engine",
"configuration_method": ConfigurationMethod.DYNAMIC_FORM,
@@ -146,8 +187,7 @@ def test_database_parameters_schema_mixin_invalid_type(get_engine_specs):
"database": "dbname",
},
}
- schema = DummySchema()
try:
- schema.load(payload)
+ dummy_schema.load(payload)
except ValidationError as err:
assert err.messages == {"port": ["Not a valid integer."]}
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index 3338ddcb61..5eb60dc6f9 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -59,7 +59,7 @@ def test_get_metrics(mocker: MockFixture) -> None:
},
]
- database.get_db_engine_spec_for_backend = mocker.MagicMock( # type: ignore
+ database.get_db_engine_spec = mocker.MagicMock( # type: ignore
return_value=CustomSqliteEngineSpec
)
assert database.get_metrics("table") == [
@@ -70,3 +70,78 @@ def test_get_metrics(mocker: MockFixture) -> None:
"verbose_name": "COUNT(DISTINCT user_id)",
},
]
+
+
+def test_get_db_engine_spec(mocker: MockFixture) -> None:
+ """
+ Tests for ``get_db_engine_spec``.
+ """
+ from superset.db_engine_specs import BaseEngineSpec
+ from superset.models.core import Database
+
+ # pylint: disable=abstract-method
+ class PostgresDBEngineSpec(BaseEngineSpec):
+ """
+ A DB engine spec with drivers and a default driver.
+ """
+
+ engine = "postgresql"
+ engine_aliases = {"postgres"}
+ drivers = {
+ "psycopg2": "The default Postgres driver",
+ "asyncpg": "An async Postgres driver",
+ }
+ default_driver = "psycopg2"
+
+ # pylint: disable=abstract-method
+ class OldDBEngineSpec(BaseEngineSpec):
+ """
+ And old DB engine spec without drivers nor a default driver.
+ """
+
+ engine = "mysql"
+
+ load_engine_specs = mocker.patch("superset.db_engine_specs.load_engine_specs")
+ load_engine_specs.return_value = [
+ PostgresDBEngineSpec,
+ OldDBEngineSpec,
+ ]
+
+ assert (
+ Database(database_name="db", sqlalchemy_uri="postgresql://").db_engine_spec
+ == PostgresDBEngineSpec
+ )
+ assert (
+ Database(
+ database_name="db", sqlalchemy_uri="postgresql+psycopg2://"
+ ).db_engine_spec
+ == PostgresDBEngineSpec
+ )
+ assert (
+ Database(
+ database_name="db", sqlalchemy_uri="postgresql+asyncpg://"
+ ).db_engine_spec
+ == PostgresDBEngineSpec
+ )
+ assert (
+ Database(
+ database_name="db", sqlalchemy_uri="postgresql+fancynewdriver://"
+ ).db_engine_spec
+ == PostgresDBEngineSpec
+ )
+ assert (
+ Database(database_name="db", sqlalchemy_uri="mysql://").db_engine_spec
+ == OldDBEngineSpec
+ )
+ assert (
+ Database(
+ database_name="db", sqlalchemy_uri="mysql+mysqlconnector://"
+ ).db_engine_spec
+ == OldDBEngineSpec
+ )
+ assert (
+ Database(
+ database_name="db", sqlalchemy_uri="mysql+fancynewdriver://"
+ ).db_engine_spec
+ == OldDBEngineSpec
+ )