You are viewing a plain-text version of this content. The canonical version is in the original mailing-list archive; its hyperlink was not preserved in this plain-text copy.
Posted to commits@superset.apache.org by hu...@apache.org on 2022/12/06 21:01:12 UTC

[superset] 01/01: refactor bind_host and bind_port

This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch create-sshtunnelconfig-tbl
in repository https://gitbox.apache.org/repos/asf/superset.git

commit ec20429a84feba863d762e4160c14d516951d40b
Author: hughhhh <hu...@gmail.com>
AuthorDate: Tue Dec 6 15:10:37 2022 -0500

    refactor bind_host and bind_port
---
 superset/databases/commands/test_connection.py     |  3 ---
 superset/databases/ssh_tunnel/models.py            |  7 ++----
 ...c2d8ec8595_create_ssh_tunnel_credentials_tbl.py |  2 --
 superset/models/core.py                            | 25 +++++++++++-----------
 4 files changed, 14 insertions(+), 23 deletions(-)

diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py
index ac1d12b78c..8027efcb49 100644
--- a/superset/databases/commands/test_connection.py
+++ b/superset/databases/commands/test_connection.py
@@ -93,9 +93,6 @@ class TestConnectionDatabaseCommand(BaseCommand):
 
             # Generate tunnel if present in the properties
             if ssh_tunnel := self._properties.get("ssh_tunnel"):
-                url = make_url_safe(database.sqlalchemy_uri_decrypted)
-                ssh_tunnel["bind_host"] = url.host
-                ssh_tunnel["bind_port"] = url.port
                 ssh_tunnel = SSHTunnel(**ssh_tunnel)
 
             event_logger.log_with_context(
diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py
index f3bcd303d9..d4ca3504cd 100644
--- a/superset/databases/ssh_tunnel/models.py
+++ b/superset/databases/ssh_tunnel/models.py
@@ -68,15 +68,12 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
         EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
     )
 
-    bind_host = sa.Column(sa.Text)
-    bind_port = sa.Column(sa.Integer)
-
-    def parameters(self) -> Dict[str, Any]:
+    def parameters(self, bind_host: str, bind_port: int) -> Dict[str, Any]:
         params = {
             "ssh_address_or_host": self.server_address,
             "ssh_port": self.server_port,
             "ssh_username": self.username,
-            "remote_bind_address": (self.bind_host, self.bind_port),
+            "remote_bind_address": (bind_host, bind_port),
             "local_bind_address": (SSH_TUNNELLING_LOCAL_BIND_ADDRESS,),
         }
 
diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py
index b90ccae50f..75bad1e53e 100644
--- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py
+++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py
@@ -69,8 +69,6 @@ def upgrade():
             encrypted_field_factory.create(sa.String(256)),
             nullable=True,
         ),
-        sa.Column("bind_host", sa.String(256)),
-        sa.Column("bind_port", sa.INTEGER()),
     )
 
 
diff --git a/superset/models/core.py b/superset/models/core.py
index 309c4444a6..77d9170c42 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -21,7 +21,7 @@ import json
 import logging
 import textwrap
 from ast import literal_eval
-from contextlib import closing, contextmanager
+from contextlib import closing, contextmanager, nullcontext
 from copy import deepcopy
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING
@@ -384,19 +384,18 @@ class Database(
         ):
             # if ssh_tunnel is available build engine with information
             url = make_url_safe(self.sqlalchemy_uri_decrypted)
-            ssh_tunnel.bind_host = url.host
-            ssh_tunnel.bind_port = url.port
-            ssh_params = ssh_tunnel.parameters()
-            with sshtunnel.open_tunnel(**ssh_params) as server:
-                yield self._get_sqla_engine(
-                    schema=schema,
-                    nullpool=nullpool,
-                    source=source,
-                    ssh_tunnel_server=server,
-                )
-
+            ssh_params = ssh_tunnel.parameters(bind_host=url.host, bind_port=url.port)
+            engine_context = sshtunnel.open_tunnel(**ssh_params)
         else:
-            yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
+            engine_context = nullcontext()
+
+        with engine_context as server_context:
+            yield self._get_sqla_engine(
+                schema=schema,
+                nullpool=nullpool,
+                source=source,
+                ssh_tunnel_server=server_context,
+            )
 
     def _get_sqla_engine(
         self,