You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by el...@apache.org on 2023/01/27 00:53:43 UTC

[superset] branch master updated: feat(ssh_tunnel): Add feature flag to SSH Tunnel API (#22805)

This is an automated email from the ASF dual-hosted git repository.

elizabeth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new d6a4a5da79 feat(ssh_tunnel): Add feature flag to SSH Tunnel API (#22805)
d6a4a5da79 is described below

commit d6a4a5da7976070cb949409763b22519a0d3f379
Author: Antonio Rivero Martinez <38...@users.noreply.github.com>
AuthorDate: Thu Jan 26 21:53:36 2023 -0300

    feat(ssh_tunnel): Add feature flag to SSH Tunnel API (#22805)
---
 superset/databases/api.py                          |  20 +++-
 superset/databases/commands/create.py              |   7 +-
 superset/databases/commands/test_connection.py     |  17 +++-
 superset/databases/commands/update.py              |   8 +-
 superset/databases/ssh_tunnel/commands/delete.py   |   4 +
 .../databases/ssh_tunnel/commands/exceptions.py    |   5 +
 tests/integration_tests/databases/api_tests.py     | 106 +++++++++++++++++++--
 .../ssh_tunnel/commands/commands_tests.py          |   4 +-
 tests/unit_tests/databases/api_test.py             |   8 ++
 .../databases/ssh_tunnel/commands/delete_test.py   |  11 ++-
 10 files changed, 171 insertions(+), 19 deletions(-)

diff --git a/superset/databases/api.py b/superset/databases/api.py
index 4866cbe775..572f3b340a 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -75,6 +75,7 @@ from superset.databases.schemas import (
 from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
 from superset.databases.ssh_tunnel.commands.exceptions import (
     SSHTunnelDeleteFailedError,
+    SSHTunnelingNotEnabledError,
     SSHTunnelNotFoundError,
 )
 from superset.databases.utils import get_table_metadata
@@ -349,6 +350,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
                 exc_info=True,
             )
             return self.response_422(message=str(ex))
+        except SSHTunnelingNotEnabledError as ex:
+            return self.response_400(message=str(ex))
         except SupersetException as ex:
             return self.response(ex.status, message=ex.message)
 
@@ -433,6 +436,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
                 exc_info=True,
             )
             return self.response_422(message=str(ex))
+        except SSHTunnelingNotEnabledError as ex:
+            return self.response_400(message=str(ex))
 
     @expose("/<int:pk>", methods=["DELETE"])
     @protect()
@@ -782,8 +787,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
         # This validates custom Schema with custom validations
         except ValidationError as error:
             return self.response_400(message=error.messages)
-        TestConnectionDatabaseCommand(item).run()
-        return self.response(200, message="OK")
+        try:
+            TestConnectionDatabaseCommand(item).run()
+            return self.response(200, message="OK")
+        except SSHTunnelingNotEnabledError as ex:
+            return self.response_400(message=str(ex))
 
     @expose("/<int:pk>/related_objects/", methods=["GET"])
     @protect()
@@ -1320,3 +1328,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
                 exc_info=True,
             )
             return self.response_422(message=str(ex))
+        except SSHTunnelingNotEnabledError as ex:
+            logger.error(
+                "Error deleting SSH Tunnel %s: %s",
+                self.__class__.__name__,
+                str(ex),
+                exc_info=True,
+            )
+            return self.response_400(message=str(ex))
diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py
index c826d82835..0ed2354960 100644
--- a/superset/databases/commands/create.py
+++ b/superset/databases/commands/create.py
@@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional
 from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError
 
+from superset import is_feature_enabled
 from superset.commands.base import BaseCommand
 from superset.dao.exceptions import DAOCreateFailedError
 from superset.databases.commands.exceptions import (
@@ -34,6 +35,7 @@ from superset.databases.dao import DatabaseDAO
 from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
 from superset.databases.ssh_tunnel.commands.exceptions import (
     SSHTunnelCreateFailedError,
+    SSHTunnelingNotEnabledError,
     SSHTunnelInvalidError,
 )
 from superset.exceptions import SupersetErrorsException
@@ -52,7 +54,7 @@ class CreateDatabaseCommand(BaseCommand):
         try:
             # Test connection before starting create transaction
             TestConnectionDatabaseCommand(self._properties).run()
-        except SupersetErrorsException as ex:
+        except (SupersetErrorsException, SSHTunnelingNotEnabledError) as ex:
             event_logger.log_with_context(
                 action=f"db_creation_failed.{ex.__class__.__name__}",
                 engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
@@ -78,6 +80,9 @@ class CreateDatabaseCommand(BaseCommand):
 
             ssh_tunnel = None
             if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
+                if not is_feature_enabled("SSH_TUNNELING"):
+                    db.session.rollback()
+                    raise SSHTunnelingNotEnabledError()
                 try:
                     # So database.id is not None
                     db.session.flush()
diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py
index 002adf1236..c5e7dc48f9 100644
--- a/superset/databases/commands/test_connection.py
+++ b/superset/databases/commands/test_connection.py
@@ -25,6 +25,7 @@ from func_timeout import func_timeout, FunctionTimedOut
 from sqlalchemy.engine import Engine
 from sqlalchemy.exc import DBAPIError, NoSuchModuleError
 
+from superset import is_feature_enabled
 from superset.commands.base import BaseCommand
 from superset.databases.commands.exceptions import (
     DatabaseSecurityUnsafeError,
@@ -32,6 +33,9 @@ from superset.databases.commands.exceptions import (
     DatabaseTestConnectionUnexpectedError,
 )
 from superset.databases.dao import DatabaseDAO
+from superset.databases.ssh_tunnel.commands.exceptions import (
+    SSHTunnelingNotEnabledError,
+)
 from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
 from superset.databases.ssh_tunnel.models import SSHTunnel
 from superset.databases.utils import make_url_safe
@@ -64,7 +68,7 @@ class TestConnectionDatabaseCommand(BaseCommand):
         self._properties = data.copy()
         self._model: Optional[Database] = None
 
-    def run(self) -> None:  # pylint: disable=too-many-statements
+    def run(self) -> None:  # pylint: disable=too-many-statements, too-many-branches
         self.validate()
         ex_str = ""
         uri = self._properties.get("sqlalchemy_uri", "")
@@ -107,6 +111,8 @@ class TestConnectionDatabaseCommand(BaseCommand):
 
             # Generate tunnel if present in the properties
             if ssh_tunnel:
+                if not is_feature_enabled("SSH_TUNNELING"):
+                    raise SSHTunnelingNotEnabledError()
                 # If there's an existing tunnel for that DB we need to use the stored
                 # password, private_key and private_key_password instead
                 if ssh_tunnel_id := ssh_tunnel.pop("id", None):
@@ -203,6 +209,15 @@ class TestConnectionDatabaseCommand(BaseCommand):
             )
             # bubble up the exception to return a 408
             raise ex
+        except SSHTunnelingNotEnabledError as ex:
+            event_logger.log_with_context(
+                action=get_log_connection_action(
+                    "test_connection_error", ssh_tunnel, ex
+                ),
+                engine=database.db_engine_spec.__name__,
+            )
+            # bubble up the exception to return a 400
+            raise ex
         except Exception as ex:
             event_logger.log_with_context(
                 action=get_log_connection_action(
diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py
index 2e5931788e..0353180355 100644
--- a/superset/databases/commands/update.py
+++ b/superset/databases/commands/update.py
@@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional
 from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError
 
+from superset import is_feature_enabled
 from superset.commands.base import BaseCommand
 from superset.dao.exceptions import DAOCreateFailedError, DAOUpdateFailedError
 from superset.databases.commands.exceptions import (
@@ -33,7 +34,9 @@ from superset.databases.dao import DatabaseDAO
 from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
 from superset.databases.ssh_tunnel.commands.exceptions import (
     SSHTunnelCreateFailedError,
+    SSHTunnelingNotEnabledError,
     SSHTunnelInvalidError,
+    SSHTunnelUpdateFailedError,
 )
 from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
 from superset.extensions import db, security_manager
@@ -102,6 +105,9 @@ class UpdateDatabaseCommand(BaseCommand):
                 )
 
             if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
+                if not is_feature_enabled("SSH_TUNNELING"):
+                    db.session.rollback()
+                    raise SSHTunnelingNotEnabledError()
                 existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
                 if existing_ssh_tunnel_model is None:
                     # We couldn't found an existing tunnel so we need to create one
@@ -118,7 +124,7 @@ class UpdateDatabaseCommand(BaseCommand):
                         UpdateSSHTunnelCommand(
                             existing_ssh_tunnel_model.id, ssh_tunnel_properties
                         ).run()
-                    except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
+                    except (SSHTunnelInvalidError, SSHTunnelUpdateFailedError) as ex:
                         # So we can show the original message
                         raise ex
                     except Exception as ex:
diff --git a/superset/databases/ssh_tunnel/commands/delete.py b/superset/databases/ssh_tunnel/commands/delete.py
index 3ad2fc2a15..235ceb697b 100644
--- a/superset/databases/ssh_tunnel/commands/delete.py
+++ b/superset/databases/ssh_tunnel/commands/delete.py
@@ -19,10 +19,12 @@ from typing import Optional
 
 from flask_appbuilder.models.sqla import Model
 
+from superset import is_feature_enabled
 from superset.commands.base import BaseCommand
 from superset.dao.exceptions import DAODeleteFailedError
 from superset.databases.ssh_tunnel.commands.exceptions import (
     SSHTunnelDeleteFailedError,
+    SSHTunnelingNotEnabledError,
     SSHTunnelNotFoundError,
 )
 from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
@@ -37,6 +39,8 @@ class DeleteSSHTunnelCommand(BaseCommand):
         self._model: Optional[SSHTunnel] = None
 
     def run(self) -> Model:
+        if not is_feature_enabled("SSH_TUNNELING"):
+            raise SSHTunnelingNotEnabledError()
         self.validate()
         try:
             ssh_tunnel = SSHTunnelDAO.delete(self._model)
diff --git a/superset/databases/ssh_tunnel/commands/exceptions.py b/superset/databases/ssh_tunnel/commands/exceptions.py
index db2d3173de..2495961c36 100644
--- a/superset/databases/ssh_tunnel/commands/exceptions.py
+++ b/superset/databases/ssh_tunnel/commands/exceptions.py
@@ -46,6 +46,11 @@ class SSHTunnelCreateFailedError(CommandException):
     message = _("Creating SSH Tunnel failed for an unknown reason")
 
 
+class SSHTunnelingNotEnabledError(CommandException):
+    status = 400
+    message = _("SSH Tunneling is not enabled")
+
+
 class SSHTunnelRequiredFieldValidationError(ValidationError):
     def __init__(self, field_name: str) -> None:
         super().__init__(
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index ae01ccdaf9..eaa1653847 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -285,15 +285,20 @@ class TestDatabaseApi(SupersetTestCase):
     @mock.patch(
         "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
     )
+    @mock.patch("superset.databases.commands.create.is_feature_enabled")
     @mock.patch(
         "superset.models.core.Database.get_all_schema_names",
     )
     def test_create_database_with_ssh_tunnel(
-        self, mock_test_connection_database_command_run, mock_get_all_schema_names
+        self,
+        mock_test_connection_database_command_run,
+        mock_create_is_feature_enabled,
+        mock_get_all_schema_names,
     ):
         """
         Database API: Test create with SSH Tunnel
         """
+        mock_create_is_feature_enabled.return_value = True
         self.login(username="admin")
         example_db = get_example_database()
         if example_db.backend == "sqlite":
@@ -328,15 +333,23 @@ class TestDatabaseApi(SupersetTestCase):
     @mock.patch(
         "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
     )
+    @mock.patch("superset.databases.commands.create.is_feature_enabled")
+    @mock.patch("superset.databases.commands.update.is_feature_enabled")
     @mock.patch(
         "superset.models.core.Database.get_all_schema_names",
     )
     def test_update_database_with_ssh_tunnel(
-        self, mock_test_connection_database_command_run, mock_get_all_schema_names
+        self,
+        mock_test_connection_database_command_run,
+        mock_create_is_feature_enabled,
+        mock_update_is_feature_enabled,
+        mock_get_all_schema_names,
     ):
         """
-        Database API: Test update with SSH Tunnel
+        Database API: Test update Database with SSH Tunnel
         """
+        mock_create_is_feature_enabled.return_value = True
+        mock_update_is_feature_enabled.return_value = True
         self.login(username="admin")
         example_db = get_example_database()
         if example_db.backend == "sqlite":
@@ -381,15 +394,23 @@ class TestDatabaseApi(SupersetTestCase):
     @mock.patch(
         "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
     )
+    @mock.patch("superset.databases.commands.create.is_feature_enabled")
+    @mock.patch("superset.databases.commands.update.is_feature_enabled")
     @mock.patch(
         "superset.models.core.Database.get_all_schema_names",
     )
     def test_update_ssh_tunnel_via_database_api(
-        self, mock_test_connection_database_command_run, mock_get_all_schema_names
+        self,
+        mock_test_connection_database_command_run,
+        mock_create_is_feature_enabled,
+        mock_update_is_feature_enabled,
+        mock_get_all_schema_names,
     ):
         """
-        Database API: Test update with SSH Tunnel
+        Database API: Test update SSH Tunnel via Database API
         """
+        mock_create_is_feature_enabled.return_value = True
+        mock_update_is_feature_enabled.return_value = True
         self.login(username="admin")
         example_db = get_example_database()
 
@@ -456,12 +477,17 @@ class TestDatabaseApi(SupersetTestCase):
     @mock.patch(
         "superset.models.core.Database.get_all_schema_names",
     )
+    @mock.patch("superset.databases.commands.create.is_feature_enabled")
     def test_cascade_delete_ssh_tunnel(
-        self, mock_test_connection_database_command_run, mock_get_all_schema_names
+        self,
+        mock_test_connection_database_command_run,
+        mock_get_all_schema_names,
+        mock_create_is_feature_enabled,
     ):
         """
-        Database API: Test create with SSH Tunnel
+        Database API: SSH Tunnel gets deleted if Database gets deleted
         """
+        mock_create_is_feature_enabled.return_value = True
         self.login(username="admin")
         example_db = get_example_database()
         if example_db.backend == "sqlite":
@@ -502,15 +528,20 @@ class TestDatabaseApi(SupersetTestCase):
     @mock.patch(
         "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
     )
+    @mock.patch("superset.databases.commands.create.is_feature_enabled")
     @mock.patch(
         "superset.models.core.Database.get_all_schema_names",
     )
     def test_do_not_create_database_if_ssh_tunnel_creation_fails(
-        self, mock_test_connection_database_command_run, mock_get_all_schema_names
+        self,
+        mock_test_connection_database_command_run,
+        mock_create_is_feature_enabled,
+        mock_get_all_schema_names,
     ):
         """
-        Database API: Test create with SSH Tunnel
+        Database API: Test Database is not created if SSH Tunnel creation fails
         """
+        mock_create_is_feature_enabled.return_value = True
         self.login(username="admin")
         example_db = get_example_database()
         if example_db.backend == "sqlite":
@@ -548,15 +579,20 @@ class TestDatabaseApi(SupersetTestCase):
     @mock.patch(
         "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
     )
+    @mock.patch("superset.databases.commands.create.is_feature_enabled")
     @mock.patch(
         "superset.models.core.Database.get_all_schema_names",
     )
     def test_get_database_returns_related_ssh_tunnel(
-        self, mock_test_connection_database_command_run, mock_get_all_schema_names
+        self,
+        mock_test_connection_database_command_run,
+        mock_create_is_feature_enabled,
+        mock_get_all_schema_names,
     ):
         """
         Database API: Test GET Database returns its related SSH Tunnel
         """
+        mock_create_is_feature_enabled.return_value = True
         self.login(username="admin")
         example_db = get_example_database()
         if example_db.backend == "sqlite":
@@ -595,6 +631,56 @@ class TestDatabaseApi(SupersetTestCase):
         db.session.delete(model)
         db.session.commit()
 
+    @mock.patch(
+        "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
+    )
+    @mock.patch(
+        "superset.models.core.Database.get_all_schema_names",
+    )
+    def test_if_ssh_tunneling_flag_is_not_active_it_raises_new_exception(
+        self,
+        mock_test_connection_database_command_run,
+        mock_get_all_schema_names,
+    ):
+        """
+        Database API: Test create with SSH Tunnel fails if SSH_TUNNELING is disabled
+        """
+        self.login(username="admin")
+        example_db = get_example_database()
+        if example_db.backend == "sqlite":
+            return
+        ssh_tunnel_properties = {
+            "server_address": "123.132.123.1",
+            "server_port": 8080,
+            "username": "foo",
+            "password": "bar",
+        }
+        database_data = {
+            "database_name": "test-db-with-ssh-tunnel-7",
+            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
+            "ssh_tunnel": ssh_tunnel_properties,
+        }
+
+        uri = "api/v1/database/"
+        rv = self.client.post(uri, json=database_data)
+        response = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(rv.status_code, 400)
+        self.assertEqual(response, {"message": "SSH Tunneling is not enabled"})
+        model_ssh_tunnel = (
+            db.session.query(SSHTunnel)
+            .filter(SSHTunnel.database_id == response.get("id"))
+            .one_or_none()
+        )
+        assert model_ssh_tunnel is None
+        # Verify that the failed request did not persist a Database row
+        model = (
+            db.session.query(Database)
+            .filter(Database.database_name == "test-db-with-ssh-tunnel-7")
+            .one_or_none()
+        )
+        # the DB should not be created
+        assert model is None
+
     def test_create_database_invalid_configuration_method(self):
         """
         Database API: Test create with an invalid configuration method.
diff --git a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py
index 75e5a55e86..86c280b9bb 100644
--- a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py
+++ b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py
@@ -67,8 +67,10 @@ class TestUpdateSSHTunnelCommand(SupersetTestCase):
 
 class TestDeleteSSHTunnelCommand(SupersetTestCase):
     @mock.patch("superset.utils.core.g")
-    def test_delete_ssh_tunnel_not_found(self, mock_g):
+    @mock.patch("superset.databases.ssh_tunnel.commands.delete.is_feature_enabled")
+    def test_delete_ssh_tunnel_not_found(self, mock_g, mock_delete_is_feature_enabled):
         mock_g.user = security_manager.find_user("admin")
+        mock_delete_is_feature_enabled.return_value = True
         # We have not created a SSH Tunnel yet so id = 1 is invalid
         command = DeleteSSHTunnelCommand(1)
         with pytest.raises(SSHTunnelNotFoundError) as excinfo:
diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py
index fe4211289c..68a9add12e 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -241,6 +241,10 @@ def test_delete_ssh_tunnel(
         # mock the lookup so that we don't need to include the driver
         mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
         mocker.patch("superset.utils.log.DBEventLogger.log")
+        mocker.patch(
+            "superset.databases.ssh_tunnel.commands.delete.is_feature_enabled",
+            return_value=True,
+        )
 
         # Create our SSHTunnel
         tunnel = SSHTunnel(
@@ -313,6 +317,10 @@ def test_delete_ssh_tunnel_not_found(
         # mock the lookup so that we don't need to include the driver
         mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
         mocker.patch("superset.utils.log.DBEventLogger.log")
+        mocker.patch(
+            "superset.databases.ssh_tunnel.commands.delete.is_feature_enabled",
+            return_value=True,
+        )
 
         # Create our SSHTunnel
         tunnel = SSHTunnel(
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
index 17afebfa0f..b5adf765fa 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
@@ -18,6 +18,7 @@
 from typing import Iterator
 
 import pytest
+from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session
 
 
@@ -50,7 +51,9 @@ def session_with_data(session: Session) -> Iterator[Session]:
     session.rollback()
 
 
-def test_delete_ssh_tunnel_command(session_with_data: Session) -> None:
+def test_delete_ssh_tunnel_command(
+    mocker: MockFixture, session_with_data: Session
+) -> None:
     from superset.databases.dao import DatabaseDAO
     from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
     from superset.databases.ssh_tunnel.models import SSHTunnel
@@ -60,9 +63,11 @@ def test_delete_ssh_tunnel_command(session_with_data: Session) -> None:
     assert result
     assert isinstance(result, SSHTunnel)
     assert 1 == result.database_id
-
+    mocker.patch(
+        "superset.databases.ssh_tunnel.commands.delete.is_feature_enabled",
+        return_value=True,
+    )
     DeleteSSHTunnelCommand(1).run()
-
     result = DatabaseDAO.get_ssh_tunnel(1)
 
     assert result is None