You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by di...@apache.org on 2024/02/07 16:03:26 UTC

(superset) branch master updated: chore: Remove database ID dependency for SSH Tunnel creation (#26989)

This is an automated email from the ASF dual-hosted git repository.

diegopucci pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new d8e26cfff1 chore: Remove database ID dependency for SSH Tunnel creation (#26989)
d8e26cfff1 is described below

commit d8e26cfff1a38155ad54ba65741049d7b60346e6
Author: Geido <60...@users.noreply.github.com>
AuthorDate: Wed Feb 7 18:03:19 2024 +0200

    chore: Remove database ID dependency for SSH Tunnel creation (#26989)
---
 superset/commands/database/create.py               | 58 ++++++++++++----------
 superset/commands/database/ssh_tunnel/create.py    | 23 +++------
 superset/commands/database/update.py               |  2 +-
 tests/integration_tests/databases/api_tests.py     | 16 +++---
 .../ssh_tunnel/commands/commands_tests.py          | 17 -------
 .../databases/ssh_tunnel/commands/create_test.py   |  6 +--
 6 files changed, 50 insertions(+), 72 deletions(-)

diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py
index a012e9b2a5..cde9dd8e88 100644
--- a/superset/commands/database/create.py
+++ b/superset/commands/database/create.py
@@ -41,6 +41,7 @@ from superset.daos.database import DatabaseDAO
 from superset.daos.exceptions import DAOCreateFailedError
 from superset.exceptions import SupersetErrorsException
 from superset.extensions import db, event_logger, security_manager
+from superset.models.core import Database
 
 logger = logging.getLogger(__name__)
 stats_logger = current_app.config["STATS_LOGGER"]
@@ -76,34 +77,20 @@ class CreateDatabaseCommand(BaseCommand):
             "{}",
         )
 
+        ssh_tunnel = None
+
         try:
-            database = DatabaseDAO.create(attributes=self._properties, commit=False)
-            database.set_sqlalchemy_uri(database.sqlalchemy_uri)
+            database = self._create_database()
 
-            ssh_tunnel = None
             if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
                 if not is_feature_enabled("SSH_TUNNELING"):
-                    db.session.rollback()
                     raise SSHTunnelingNotEnabledError()
-                try:
-                    # So database.id is not None
-                    db.session.flush()
-                    ssh_tunnel = CreateSSHTunnelCommand(
-                        database.id, ssh_tunnel_properties
-                    ).run()
-                except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
-                    event_logger.log_with_context(
-                        action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
-                        engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
-                    )
-                    # So we can show the original message
-                    raise ex
-                except Exception as ex:
-                    event_logger.log_with_context(
-                        action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
-                        engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
-                    )
-                    raise DatabaseCreateFailedError() from ex
+
+                ssh_tunnel = CreateSSHTunnelCommand(
+                    database, ssh_tunnel_properties
+                ).run()
+
+            db.session.commit()
 
             # adding a new database we always want to force refresh schema list
             schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
@@ -112,9 +99,23 @@ class CreateDatabaseCommand(BaseCommand):
                     "schema_access", security_manager.get_schema_perm(database, schema)
                 )
 
-            db.session.commit()
-
-        except DAOCreateFailedError as ex:
+        except (
+            SSHTunnelInvalidError,
+            SSHTunnelCreateFailedError,
+            SSHTunnelingNotEnabledError,
+        ) as ex:
+            db.session.rollback()
+            event_logger.log_with_context(
+                action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
+                engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
+            )
+            # So we can show the original message
+            raise ex
+        except (
+            DAOCreateFailedError,
+            DatabaseInvalidError,
+            Exception,
+        ) as ex:
             db.session.rollback()
             event_logger.log_with_context(
                 action=f"db_creation_failed.{ex.__class__.__name__}",
@@ -150,3 +151,8 @@ class CreateDatabaseCommand(BaseCommand):
                 )
             )
             raise exception
+
+    def _create_database(self) -> Database:
+        database = DatabaseDAO.create(attributes=self._properties, commit=False)
+        database.set_sqlalchemy_uri(database.sqlalchemy_uri)
+        return database
diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py
index 07209f010b..cbfee3ce2a 100644
--- a/superset/commands/database/ssh_tunnel/create.py
+++ b/superset/commands/database/ssh_tunnel/create.py
@@ -28,39 +28,32 @@ from superset.commands.database.ssh_tunnel.exceptions import (
 )
 from superset.daos.database import SSHTunnelDAO
 from superset.daos.exceptions import DAOCreateFailedError
-from superset.extensions import db, event_logger
+from superset.extensions import event_logger
+from superset.models.core import Database
 
 logger = logging.getLogger(__name__)
 
 
 class CreateSSHTunnelCommand(BaseCommand):
-    def __init__(self, database_id: int, data: dict[str, Any]):
+    def __init__(self, database: Database, data: dict[str, Any]):
         self._properties = data.copy()
-        self._properties["database_id"] = database_id
+        self._properties["database"] = database
 
     def run(self) -> Model:
         try:
-            # Start nested transaction since we are always creating the tunnel
-            # through a DB command (Create or Update). Without this, we cannot
-            # safely rollback changes to databases if any, i.e, things like
-            # test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
-            db.session.begin_nested()
             self.validate()
-            return SSHTunnelDAO.create(attributes=self._properties, commit=False)
+            ssh_tunnel = SSHTunnelDAO.create(attributes=self._properties, commit=False)
+            return ssh_tunnel
         except DAOCreateFailedError as ex:
-            # Rollback nested transaction
-            db.session.rollback()
             raise SSHTunnelCreateFailedError() from ex
         except SSHTunnelInvalidError as ex:
-            # Rollback nested transaction
-            db.session.rollback()
             raise ex
 
     def validate(self) -> None:
         # TODO(hughhh): check to make sure the server port is not localhost
         # using the config.SSH_TUNNEL_MANAGER
+
         exceptions: list[ValidationError] = []
-        database_id: Optional[int] = self._properties.get("database_id")
         server_address: Optional[str] = self._properties.get("server_address")
         server_port: Optional[int] = self._properties.get("server_port")
         username: Optional[str] = self._properties.get("username")
@@ -68,8 +61,6 @@ class CreateSSHTunnelCommand(BaseCommand):
         private_key_password: Optional[str] = self._properties.get(
             "private_key_password"
         )
-        if not database_id:
-            exceptions.append(SSHTunnelRequiredFieldValidationError("database_id"))
         if not server_address:
             exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
         if not server_port:
diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py
index 039d731d72..edc0ba1b98 100644
--- a/superset/commands/database/update.py
+++ b/superset/commands/database/update.py
@@ -78,7 +78,7 @@ class UpdateDatabaseCommand(BaseCommand):
                 if existing_ssh_tunnel_model is None:
                     # We couldn't found an existing tunnel so we need to create one
                     try:
-                        CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run()
+                        CreateSSHTunnelCommand(database, ssh_tunnel_properties).run()
                     except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
                         # So we can show the original message
                         raise ex
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index 0bc1f245a1..f7b8cc0ec8 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -538,14 +538,16 @@ class TestDatabaseApi(SupersetTestCase):
     @mock.patch(
         "superset.models.core.Database.get_all_schema_names",
     )
+    @mock.patch("superset.extensions.db.session.rollback")
     def test_do_not_create_database_if_ssh_tunnel_creation_fails(
         self,
+        mock_rollback,
         mock_test_connection_database_command_run,
         mock_create_is_feature_enabled,
         mock_get_all_schema_names,
     ):
         """
-        Database API: Test Database is not created if SSH Tunnel creation fails
+        Database API: Test rollback is called if SSH Tunnel creation fails
         """
         mock_create_is_feature_enabled.return_value = True
         self.login(username="admin")
@@ -566,6 +568,7 @@ class TestDatabaseApi(SupersetTestCase):
         rv = self.client.post(uri, json=database_data)
         response = json.loads(rv.data.decode("utf-8"))
         self.assertEqual(rv.status_code, 422)
+
         model_ssh_tunnel = (
             db.session.query(SSHTunnel)
             .filter(SSHTunnel.database_id == response.get("id"))
@@ -573,14 +576,9 @@ class TestDatabaseApi(SupersetTestCase):
         )
         assert model_ssh_tunnel is None
         self.assertEqual(response, fail_message)
-        # Cleanup
-        model = (
-            db.session.query(Database)
-            .filter(Database.database_name == "test-db-failure-ssh-tunnel")
-            .one_or_none()
-        )
-        # the DB should not be created
-        assert model is None
+
+        # Check that rollback was called
+        mock_rollback.assert_called()
 
     @mock.patch(
         "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
diff --git a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py
index 1cd9afcc80..f6e5ca9d09 100644
--- a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py
+++ b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py
@@ -30,23 +30,6 @@ from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
 from tests.integration_tests.base_tests import SupersetTestCase
 
 
-class TestCreateSSHTunnelCommand(SupersetTestCase):
-    @mock.patch("superset.utils.core.g")
-    def test_create_invalid_database_id(self, mock_g):
-        mock_g.user = security_manager.find_user("admin")
-        command = CreateSSHTunnelCommand(
-            None,
-            {
-                "server_address": "127.0.0.1",
-                "server_port": 5432,
-                "username": "test_user",
-            },
-        )
-        with pytest.raises(SSHTunnelInvalidError) as excinfo:
-            command.run()
-        assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
-
-
 class TestUpdateSSHTunnelCommand(SupersetTestCase):
     @mock.patch("superset.utils.core.g")
     def test_update_ssh_tunnel_not_found(self, mock_g):
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
index bd891b64f0..1777bdc2e1 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
@@ -37,7 +37,7 @@ def test_create_ssh_tunnel_command() -> None:
         "password": "bar",
     }
 
-    result = CreateSSHTunnelCommand(db.id, properties).run()
+    result = CreateSSHTunnelCommand(db, properties).run()
 
     assert result is not None
     assert isinstance(result, SSHTunnel)
@@ -53,14 +53,14 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
     # If we are trying to create a tunnel with a private_key_password
     # then a private_key is mandatory
     properties = {
-        "database_id": db.id,
+        "database": db,
         "server_address": "123.132.123.1",
         "server_port": "3005",
         "username": "foo",
         "private_key_password": "bar",
     }
 
-    command = CreateSSHTunnelCommand(db.id, properties)
+    command = CreateSSHTunnelCommand(db, properties)
 
     with pytest.raises(SSHTunnelInvalidError) as excinfo:
         command.run()