You are viewing a plain-text version of this content. The canonical link to the original was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@superset.apache.org by di...@apache.org on 2024/02/21 10:13:28 UTC

(superset) 06/10: Catch missing database port for SSH Tunnel

This is an automated email from the ASF dual-hosted git repository.

diegopucci pushed a commit to branch diego/ch78628/fix-disabled-ssh-toggle
in repository https://gitbox.apache.org/repos/asf/superset.git

commit b94994f07fd3a121cd17536dc610a3d48bc325ba
Author: geido <di...@gmail.com>
AuthorDate: Tue Feb 20 13:23:12 2024 +0200

    Catch missing database port for SSH Tunnel
---
 superset/commands/database/create.py               |  4 ++
 superset/commands/database/ssh_tunnel/create.py    |  8 ++++
 .../commands/database/ssh_tunnel/exceptions.py     |  4 ++
 superset/commands/database/ssh_tunnel/update.py    |  6 +++
 superset/commands/database/test_connection.py      | 52 ++++++++++++----------
 superset/commands/database/update.py               | 33 ++++++++------
 superset/databases/api.py                          |  5 ++-
 7 files changed, 73 insertions(+), 39 deletions(-)

diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py
index cde9dd8e88..1ddc08e6a1 100644
--- a/superset/commands/database/create.py
+++ b/superset/commands/database/create.py
@@ -19,6 +19,7 @@ from typing import Any, Optional
 
 from flask import current_app
 from flask_appbuilder.models.sqla import Model
+from flask_babel import gettext as _
 from marshmallow import ValidationError
 
 from superset import is_feature_enabled
@@ -33,6 +34,7 @@ from superset.commands.database.exceptions import (
 from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
 from superset.commands.database.ssh_tunnel.exceptions import (
     SSHTunnelCreateFailedError,
+    SSHTunnelDatabasePortError,
     SSHTunnelingNotEnabledError,
     SSHTunnelInvalidError,
 )
@@ -103,6 +105,7 @@ class CreateDatabaseCommand(BaseCommand):
             SSHTunnelInvalidError,
             SSHTunnelCreateFailedError,
             SSHTunnelingNotEnabledError,
+            SSHTunnelDatabasePortError,
         ) as ex:
             db.session.rollback()
             event_logger.log_with_context(
@@ -140,6 +143,7 @@ class CreateDatabaseCommand(BaseCommand):
             # Check database_name uniqueness
             if not DatabaseDAO.validate_uniqueness(database_name):
                 exceptions.append(DatabaseExistsValidationError())
+
         if exceptions:
             exception = DatabaseInvalidError()
             exception.extend(exceptions)
diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py
index 59e083d4d8..287accc5aa 100644
--- a/superset/commands/database/ssh_tunnel/create.py
+++ b/superset/commands/database/ssh_tunnel/create.py
@@ -23,11 +23,13 @@ from marshmallow import ValidationError
 from superset.commands.base import BaseCommand
 from superset.commands.database.ssh_tunnel.exceptions import (
     SSHTunnelCreateFailedError,
+    SSHTunnelDatabasePortError,
     SSHTunnelInvalidError,
     SSHTunnelRequiredFieldValidationError,
 )
 from superset.daos.database import SSHTunnelDAO
 from superset.daos.exceptions import DAOCreateFailedError
+from superset.databases.utils import make_url_safe
 from superset.extensions import event_logger
 from superset.models.core import Database
 
@@ -35,9 +37,12 @@ logger = logging.getLogger(__name__)
 
 
 class CreateSSHTunnelCommand(BaseCommand):
+    _database: Database
+
     def __init__(self, database: Database, data: dict[str, Any]):
         self._properties = data.copy()
         self._properties["database"] = database
+        self._database = database
 
     def run(self) -> Model:
         try:
@@ -62,6 +67,9 @@ class CreateSSHTunnelCommand(BaseCommand):
         private_key_password: Optional[str] = self._properties.get(
             "private_key_password"
         )
+        url = make_url_safe(self._database.sqlalchemy_uri)
+        if not url.port:
+            raise SSHTunnelDatabasePortError()
         if not server_address:
             exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
         if not server_port:
diff --git a/superset/commands/database/ssh_tunnel/exceptions.py b/superset/commands/database/ssh_tunnel/exceptions.py
index 0e3f91cae6..a0def8c087 100644
--- a/superset/commands/database/ssh_tunnel/exceptions.py
+++ b/superset/commands/database/ssh_tunnel/exceptions.py
@@ -38,6 +38,10 @@ class SSHTunnelInvalidError(CommandInvalidError):
     message = _("SSH Tunnel parameters are invalid.")
 
 
+class SSHTunnelDatabasePortError(CommandInvalidError):
+    message = _("A database port is required when connecting via SSH Tunnel.")
+
+
 class SSHTunnelUpdateFailedError(UpdateFailedError):
     message = _("SSH Tunnel could not be updated.")
 
diff --git a/superset/commands/database/ssh_tunnel/update.py b/superset/commands/database/ssh_tunnel/update.py
index 47f7d4947a..077ed4c321 100644
--- a/superset/commands/database/ssh_tunnel/update.py
+++ b/superset/commands/database/ssh_tunnel/update.py
@@ -21,6 +21,7 @@ from flask_appbuilder.models.sqla import Model
 
 from superset.commands.base import BaseCommand
 from superset.commands.database.ssh_tunnel.exceptions import (
+    SSHTunnelDatabasePortError,
     SSHTunnelInvalidError,
     SSHTunnelNotFoundError,
     SSHTunnelRequiredFieldValidationError,
@@ -29,6 +30,7 @@ from superset.commands.database.ssh_tunnel.exceptions import (
 from superset.daos.database import SSHTunnelDAO
 from superset.daos.exceptions import DAOUpdateFailedError
 from superset.databases.ssh_tunnel.models import SSHTunnel
+from superset.databases.utils import make_url_safe
 
 logger = logging.getLogger(__name__)
 
@@ -62,6 +64,8 @@ class UpdateSSHTunnelCommand(BaseCommand):
         self._model = SSHTunnelDAO.find_by_id(self._model_id)
         if not self._model:
             raise SSHTunnelNotFoundError()
+
+        url = make_url_safe(self._model.database.sqlalchemy_uri)
         private_key: Optional[str] = self._properties.get("private_key")
         private_key_password: Optional[str] = self._properties.get(
             "private_key_password"
@@ -70,3 +74,5 @@ class UpdateSSHTunnelCommand(BaseCommand):
             raise SSHTunnelInvalidError(
                 exceptions=[SSHTunnelRequiredFieldValidationError("private_key")]
             )
+        if not url.port:
+            raise SSHTunnelDatabasePortError()
diff --git a/superset/commands/database/test_connection.py b/superset/commands/database/test_connection.py
index 0ffdf3ddd9..e91eec3a89 100644
--- a/superset/commands/database/test_connection.py
+++ b/superset/commands/database/test_connection.py
@@ -32,8 +32,11 @@ from superset.commands.database.exceptions import (
     DatabaseTestConnectionDriverError,
     DatabaseTestConnectionUnexpectedError,
 )
-from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelingNotEnabledError
-from superset.daos.database import DatabaseDAO, SSHTunnelDAO
+from superset.commands.database.ssh_tunnel.exceptions import (
+    SSHTunnelDatabasePortError,
+    SSHTunnelingNotEnabledError,
+)
+from superset.daos.database import DatabaseDAO
 from superset.databases.ssh_tunnel.models import SSHTunnel
 from superset.databases.utils import make_url_safe
 from superset.errors import ErrorLevel, SupersetErrorType
@@ -44,7 +47,6 @@ from superset.exceptions import (
 )
 from superset.extensions import event_logger
 from superset.models.core import Database
-from superset.utils.ssh_tunnel import unmask_password_info
 
 logger = logging.getLogger(__name__)
 
@@ -61,20 +63,22 @@ def get_log_connection_action(
 
 
 class TestConnectionDatabaseCommand(BaseCommand):
+    _model: Optional[Database] = None
+    _context: dict[str, Any]
+    _uri: str
+
     def __init__(self, data: dict[str, Any]):
         self._properties = data.copy()
-        self._model: Optional[Database] = None
 
-    def run(self) -> None:  # pylint: disable=too-many-statements, too-many-branches
-        self.validate()
-        ex_str = ""
+        if (database_name := self._properties.get("database_name")) is not None:
+            self._model = DatabaseDAO.get_database_by_name(database_name)
+
         uri = self._properties.get("sqlalchemy_uri", "")
         if self._model and uri == self._model.safe_sqlalchemy_uri():
             uri = self._model.sqlalchemy_uri_decrypted
-        ssh_tunnel = self._properties.get("ssh_tunnel")
 
-        # context for error messages
         url = make_url_safe(uri)
+
         context = {
             "hostname": url.host,
             "password": url.password,
@@ -83,6 +87,14 @@ class TestConnectionDatabaseCommand(BaseCommand):
             "database": url.database,
         }
 
+        self._context = context
+        self._uri = uri
+
+    def run(self) -> None:  # pylint: disable=too-many-statements, too-many-branches
+        self.validate()
+        ex_str = ""
+        ssh_tunnel = self._properties.get("ssh_tunnel")
+
         serialized_encrypted_extra = self._properties.get(
             "masked_encrypted_extra",
             "{}",
@@ -103,20 +115,11 @@ class TestConnectionDatabaseCommand(BaseCommand):
                 encrypted_extra=serialized_encrypted_extra,
             )
 
-            database.set_sqlalchemy_uri(uri)
+            database.set_sqlalchemy_uri(self._uri)
             database.db_engine_spec.mutate_db_for_connection_test(database)
 
             # Generate tunnel if present in the properties
             if ssh_tunnel:
-                if not is_feature_enabled("SSH_TUNNELING"):
-                    raise SSHTunnelingNotEnabledError()
-                # If there's an existing tunnel for that DB we need to use the stored
-                # password, private_key and private_key_password instead
-                if ssh_tunnel_id := ssh_tunnel.pop("id", None):
-                    if existing_ssh_tunnel := SSHTunnelDAO.find_by_id(ssh_tunnel_id):
-                        ssh_tunnel = unmask_password_info(
-                            ssh_tunnel, existing_ssh_tunnel
-                        )
                 ssh_tunnel = SSHTunnel(**ssh_tunnel)
 
             event_logger.log_with_context(
@@ -186,7 +189,7 @@ class TestConnectionDatabaseCommand(BaseCommand):
                 engine=database.db_engine_spec.__name__,
             )
             # check for custom errors (wrong username, wrong password, etc)
-            errors = database.db_engine_spec.extract_errors(ex, context)
+            errors = database.db_engine_spec.extract_errors(ex, self._context)
             raise SupersetErrorsException(errors) from ex
         except SupersetSecurityException as ex:
             event_logger.log_with_context(
@@ -221,9 +224,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
                 ),
                 engine=database.db_engine_spec.__name__,
             )
-            errors = database.db_engine_spec.extract_errors(ex, context)
+            errors = database.db_engine_spec.extract_errors(ex, self._context)
             raise DatabaseTestConnectionUnexpectedError(errors) from ex
 
     def validate(self) -> None:
-        if (database_name := self._properties.get("database_name")) is not None:
-            self._model = DatabaseDAO.get_database_by_name(database_name)
+        if self._properties.get("ssh_tunnel"):
+            if not is_feature_enabled("SSH_TUNNELING"):
+                raise SSHTunnelingNotEnabledError()
+            if not self._context.get("port"):
+                raise SSHTunnelDatabasePortError()
diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py
index b891c8f157..88539a2c7b 100644
--- a/superset/commands/database/update.py
+++ b/superset/commands/database/update.py
@@ -18,6 +18,7 @@ import logging
 from typing import Any, Optional
 
 from flask_appbuilder.models.sqla import Model
+from flask_babel import gettext as _
 from marshmallow import ValidationError
 
 from superset import is_feature_enabled
@@ -33,6 +34,7 @@ from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
 from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
 from superset.commands.database.ssh_tunnel.exceptions import (
     SSHTunnelCreateFailedError,
+    SSHTunnelDatabasePortError,
     SSHTunnelDeleteFailedError,
     SSHTunnelingNotEnabledError,
     SSHTunnelInvalidError,
@@ -49,15 +51,19 @@ logger = logging.getLogger(__name__)
 
 
 class UpdateDatabaseCommand(BaseCommand):
+    _model: Optional[Database]
+
     def __init__(self, model_id: int, data: dict[str, Any]):
         self._properties = data.copy()
         self._model_id = model_id
-        self._model: Optional[Database] = None
+        self._model = DatabaseDAO.find_by_id(self._model_id)
 
     def run(self) -> Model:
-        self.validate()
         if not self._model:
             raise DatabaseNotFoundError()
+
+        self.validate()
+
         old_database_name = self._model.database_name
 
         # unmask ``encrypted_extra``
@@ -72,32 +78,34 @@ class UpdateDatabaseCommand(BaseCommand):
             database = DatabaseDAO.update(self._model, self._properties, commit=False)
             database.set_sqlalchemy_uri(database.sqlalchemy_uri)
 
-            existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
+            ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
 
             if "ssh_tunnel" in self._properties:
                 if not is_feature_enabled("SSH_TUNNELING"):
                     db.session.rollback()
                     raise SSHTunnelingNotEnabledError()
 
-                if not self._properties.get("ssh_tunnel") and existing_ssh_tunnel_model:
+                if not self._properties.get("ssh_tunnel") and ssh_tunnel:
                     # We need to remove the existing tunnel
                     try:
-                        DeleteSSHTunnelCommand(existing_ssh_tunnel_model.id).run()
+                        DeleteSSHTunnelCommand(ssh_tunnel.id).run()
+                        ssh_tunnel = None
                     except SSHTunnelDeleteFailedError as ex:
                         raise ex
                     except Exception as ex:
                         raise DatabaseUpdateFailedError() from ex
 
                 if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
-                    if existing_ssh_tunnel_model is None:
+                    if ssh_tunnel is None:
                         # We couldn't found an existing tunnel so we need to create one
                         try:
-                            CreateSSHTunnelCommand(
+                            ssh_tunnel = CreateSSHTunnelCommand(
                                 database, ssh_tunnel_properties
                             ).run()
                         except (
                             SSHTunnelInvalidError,
                             SSHTunnelCreateFailedError,
+                            SSHTunnelDatabasePortError,
                         ) as ex:
                             # So we can show the original message
                             raise ex
@@ -106,12 +114,14 @@ class UpdateDatabaseCommand(BaseCommand):
                     else:
                         # We found an existing tunnel so we need to update it
                         try:
-                            UpdateSSHTunnelCommand(
-                                existing_ssh_tunnel_model.id, ssh_tunnel_properties
+                            ssh_tunnel_id = ssh_tunnel.id
+                            ssh_tunnel = UpdateSSHTunnelCommand(
+                                ssh_tunnel_id, ssh_tunnel_properties
                             ).run()
                         except (
                             SSHTunnelInvalidError,
                             SSHTunnelUpdateFailedError,
+                            SSHTunnelDatabasePortError,
                         ) as ex:
                             # So we can show the original message
                             raise ex
@@ -121,7 +131,6 @@ class UpdateDatabaseCommand(BaseCommand):
             # adding a new database we always want to force refresh schema list
             # TODO Improve this simplistic implementation for catching DB conn fails
             try:
-                ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
                 schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
             except Exception as ex:
                 db.session.rollback()
@@ -189,10 +198,6 @@ class UpdateDatabaseCommand(BaseCommand):
 
     def validate(self) -> None:
         exceptions: list[ValidationError] = []
-        # Validate/populate model exists
-        self._model = DatabaseDAO.find_by_id(self._model_id)
-        if not self._model:
-            raise DatabaseNotFoundError()
         database_name: Optional[str] = self._properties.get("database_name")
         if database_name:
             # Check database_name uniqueness
diff --git a/superset/databases/api.py b/superset/databases/api.py
index 2f95bd0442..e6aca61a20 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -47,6 +47,7 @@ from superset.commands.database.export import ExportDatabasesCommand
 from superset.commands.database.importers.dispatcher import ImportDatabasesCommand
 from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
 from superset.commands.database.ssh_tunnel.exceptions import (
+    SSHTunnelDatabasePortError,
     SSHTunnelDeleteFailedError,
     SSHTunnelingNotEnabledError,
 )
@@ -415,7 +416,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
                 exc_info=True,
             )
             return self.response_422(message=str(ex))
-        except SSHTunnelingNotEnabledError as ex:
+        except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
             return self.response_400(message=str(ex))
         except SupersetException as ex:
             return self.response(ex.status, message=ex.message)
@@ -500,7 +501,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
                 exc_info=True,
             )
             return self.response_422(message=str(ex))
-        except SSHTunnelingNotEnabledError as ex:
+        except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
             return self.response_400(message=str(ex))
 
     @expose("/<int:pk>", methods=("DELETE",))