You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by di...@apache.org on 2024/02/07 16:03:26 UTC
(superset) branch master updated: chore: Remove database ID dependency for SSH Tunnel creation (#26989)
This is an automated email from the ASF dual-hosted git repository.
diegopucci pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new d8e26cfff1 chore: Remove database ID dependency for SSH Tunnel creation (#26989)
d8e26cfff1 is described below
commit d8e26cfff1a38155ad54ba65741049d7b60346e6
Author: Geido <60...@users.noreply.github.com>
AuthorDate: Wed Feb 7 18:03:19 2024 +0200
chore: Remove database ID dependency for SSH Tunnel creation (#26989)
---
superset/commands/database/create.py | 58 ++++++++++++----------
superset/commands/database/ssh_tunnel/create.py | 23 +++------
superset/commands/database/update.py | 2 +-
tests/integration_tests/databases/api_tests.py | 16 +++---
.../ssh_tunnel/commands/commands_tests.py | 17 -------
.../databases/ssh_tunnel/commands/create_test.py | 6 +--
6 files changed, 50 insertions(+), 72 deletions(-)
diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py
index a012e9b2a5..cde9dd8e88 100644
--- a/superset/commands/database/create.py
+++ b/superset/commands/database/create.py
@@ -41,6 +41,7 @@ from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.exceptions import SupersetErrorsException
from superset.extensions import db, event_logger, security_manager
+from superset.models.core import Database
logger = logging.getLogger(__name__)
stats_logger = current_app.config["STATS_LOGGER"]
@@ -76,34 +77,20 @@ class CreateDatabaseCommand(BaseCommand):
"{}",
)
+ ssh_tunnel = None
+
try:
- database = DatabaseDAO.create(attributes=self._properties, commit=False)
- database.set_sqlalchemy_uri(database.sqlalchemy_uri)
+ database = self._create_database()
- ssh_tunnel = None
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
if not is_feature_enabled("SSH_TUNNELING"):
- db.session.rollback()
raise SSHTunnelingNotEnabledError()
- try:
- # So database.id is not None
- db.session.flush()
- ssh_tunnel = CreateSSHTunnelCommand(
- database.id, ssh_tunnel_properties
- ).run()
- except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
- event_logger.log_with_context(
- action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
- engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
- )
- # So we can show the original message
- raise ex
- except Exception as ex:
- event_logger.log_with_context(
- action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
- engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
- )
- raise DatabaseCreateFailedError() from ex
+
+ ssh_tunnel = CreateSSHTunnelCommand(
+ database, ssh_tunnel_properties
+ ).run()
+
+ db.session.commit()
# when adding a new database, we always want to force a refresh of the schema list
schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
@@ -112,9 +99,23 @@ class CreateDatabaseCommand(BaseCommand):
"schema_access", security_manager.get_schema_perm(database, schema)
)
- db.session.commit()
-
- except DAOCreateFailedError as ex:
+ except (
+ SSHTunnelInvalidError,
+ SSHTunnelCreateFailedError,
+ SSHTunnelingNotEnabledError,
+ ) as ex:
+ db.session.rollback()
+ event_logger.log_with_context(
+ action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
+ engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
+ )
+ # So we can show the original message
+ raise ex
+ except (
+ DAOCreateFailedError,
+ DatabaseInvalidError,
+ Exception,
+ ) as ex:
db.session.rollback()
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
@@ -150,3 +151,8 @@ class CreateDatabaseCommand(BaseCommand):
)
)
raise exception
+
+ def _create_database(self) -> Database:
+ database = DatabaseDAO.create(attributes=self._properties, commit=False)
+ database.set_sqlalchemy_uri(database.sqlalchemy_uri)
+ return database
diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py
index 07209f010b..cbfee3ce2a 100644
--- a/superset/commands/database/ssh_tunnel/create.py
+++ b/superset/commands/database/ssh_tunnel/create.py
@@ -28,39 +28,32 @@ from superset.commands.database.ssh_tunnel.exceptions import (
)
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOCreateFailedError
-from superset.extensions import db, event_logger
+from superset.extensions import event_logger
+from superset.models.core import Database
logger = logging.getLogger(__name__)
class CreateSSHTunnelCommand(BaseCommand):
- def __init__(self, database_id: int, data: dict[str, Any]):
+ def __init__(self, database: Database, data: dict[str, Any]):
self._properties = data.copy()
- self._properties["database_id"] = database_id
+ self._properties["database"] = database
def run(self) -> Model:
try:
- # Start nested transaction since we are always creating the tunnel
- # through a DB command (Create or Update). Without this, we cannot
- # safely rollback changes to databases if any, i.e, things like
- # test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
- db.session.begin_nested()
self.validate()
- return SSHTunnelDAO.create(attributes=self._properties, commit=False)
+ ssh_tunnel = SSHTunnelDAO.create(attributes=self._properties, commit=False)
+ return ssh_tunnel
except DAOCreateFailedError as ex:
- # Rollback nested transaction
- db.session.rollback()
raise SSHTunnelCreateFailedError() from ex
except SSHTunnelInvalidError as ex:
- # Rollback nested transaction
- db.session.rollback()
raise ex
def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost
# using the config.SSH_TUNNEL_MANAGER
+
exceptions: list[ValidationError] = []
- database_id: Optional[int] = self._properties.get("database_id")
server_address: Optional[str] = self._properties.get("server_address")
server_port: Optional[int] = self._properties.get("server_port")
username: Optional[str] = self._properties.get("username")
@@ -68,8 +61,6 @@ class CreateSSHTunnelCommand(BaseCommand):
private_key_password: Optional[str] = self._properties.get(
"private_key_password"
)
- if not database_id:
- exceptions.append(SSHTunnelRequiredFieldValidationError("database_id"))
if not server_address:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
if not server_port:
diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py
index 039d731d72..edc0ba1b98 100644
--- a/superset/commands/database/update.py
+++ b/superset/commands/database/update.py
@@ -78,7 +78,7 @@ class UpdateDatabaseCommand(BaseCommand):
if existing_ssh_tunnel_model is None:
# We couldn't find an existing tunnel, so we need to create one
try:
- CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run()
+ CreateSSHTunnelCommand(database, ssh_tunnel_properties).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
# So we can show the original message
raise ex
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index 0bc1f245a1..f7b8cc0ec8 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -538,14 +538,16 @@ class TestDatabaseApi(SupersetTestCase):
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
+ @mock.patch("superset.extensions.db.session.rollback")
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
self,
+ mock_rollback,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_get_all_schema_names,
):
"""
- Database API: Test Database is not created if SSH Tunnel creation fails
+ Database API: Test rollback is called if SSH Tunnel creation fails
"""
mock_create_is_feature_enabled.return_value = True
self.login(username="admin")
@@ -566,6 +568,7 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
+
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
@@ -573,14 +576,9 @@ class TestDatabaseApi(SupersetTestCase):
)
assert model_ssh_tunnel is None
self.assertEqual(response, fail_message)
- # Cleanup
- model = (
- db.session.query(Database)
- .filter(Database.database_name == "test-db-failure-ssh-tunnel")
- .one_or_none()
- )
- # the DB should not be created
- assert model is None
+
+ # Check that rollback was called
+ mock_rollback.assert_called()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
diff --git a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py
index 1cd9afcc80..f6e5ca9d09 100644
--- a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py
+++ b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py
@@ -30,23 +30,6 @@ from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from tests.integration_tests.base_tests import SupersetTestCase
-class TestCreateSSHTunnelCommand(SupersetTestCase):
- @mock.patch("superset.utils.core.g")
- def test_create_invalid_database_id(self, mock_g):
- mock_g.user = security_manager.find_user("admin")
- command = CreateSSHTunnelCommand(
- None,
- {
- "server_address": "127.0.0.1",
- "server_port": 5432,
- "username": "test_user",
- },
- )
- with pytest.raises(SSHTunnelInvalidError) as excinfo:
- command.run()
- assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
-
-
class TestUpdateSSHTunnelCommand(SupersetTestCase):
@mock.patch("superset.utils.core.g")
def test_update_ssh_tunnel_not_found(self, mock_g):
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
index bd891b64f0..1777bdc2e1 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
@@ -37,7 +37,7 @@ def test_create_ssh_tunnel_command() -> None:
"password": "bar",
}
- result = CreateSSHTunnelCommand(db.id, properties).run()
+ result = CreateSSHTunnelCommand(db, properties).run()
assert result is not None
assert isinstance(result, SSHTunnel)
@@ -53,14 +53,14 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
# If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory
properties = {
- "database_id": db.id,
+ "database": db,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"private_key_password": "bar",
}
- command = CreateSSHTunnelCommand(db.id, properties)
+ command = CreateSSHTunnelCommand(db, properties)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()