You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by hu...@apache.org on 2022/12/06 21:01:12 UTC
[superset] 01/01: refactor bind_host and bind_port
This is an automated email from the ASF dual-hosted git repository.
hugh pushed a commit to branch create-sshtunnelconfig-tbl
in repository https://gitbox.apache.org/repos/asf/superset.git
commit ec20429a84feba863d762e4160c14d516951d40b
Author: hughhhh <hu...@gmail.com>
AuthorDate: Tue Dec 6 15:10:37 2022 -0500
refactor bind_host and bind_port
---
superset/databases/commands/test_connection.py | 3 ---
superset/databases/ssh_tunnel/models.py | 7 ++----
...c2d8ec8595_create_ssh_tunnel_credentials_tbl.py | 2 --
superset/models/core.py | 25 +++++++++++-----------
4 files changed, 14 insertions(+), 23 deletions(-)
diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py
index ac1d12b78c..8027efcb49 100644
--- a/superset/databases/commands/test_connection.py
+++ b/superset/databases/commands/test_connection.py
@@ -93,9 +93,6 @@ class TestConnectionDatabaseCommand(BaseCommand):
# Generate tunnel if present in the properties
if ssh_tunnel := self._properties.get("ssh_tunnel"):
- url = make_url_safe(database.sqlalchemy_uri_decrypted)
- ssh_tunnel["bind_host"] = url.host
- ssh_tunnel["bind_port"] = url.port
ssh_tunnel = SSHTunnel(**ssh_tunnel)
event_logger.log_with_context(
diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py
index f3bcd303d9..d4ca3504cd 100644
--- a/superset/databases/ssh_tunnel/models.py
+++ b/superset/databases/ssh_tunnel/models.py
@@ -68,15 +68,12 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
)
- bind_host = sa.Column(sa.Text)
- bind_port = sa.Column(sa.Integer)
-
- def parameters(self) -> Dict[str, Any]:
+ def parameters(self, bind_host: str, bind_port: int) -> Dict[str, Any]:
params = {
"ssh_address_or_host": self.server_address,
"ssh_port": self.server_port,
"ssh_username": self.username,
- "remote_bind_address": (self.bind_host, self.bind_port),
+ "remote_bind_address": (bind_host, bind_port),
"local_bind_address": (SSH_TUNNELLING_LOCAL_BIND_ADDRESS,),
}
diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py
index b90ccae50f..75bad1e53e 100644
--- a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py
+++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py
@@ -69,8 +69,6 @@ def upgrade():
encrypted_field_factory.create(sa.String(256)),
nullable=True,
),
- sa.Column("bind_host", sa.String(256)),
- sa.Column("bind_port", sa.INTEGER()),
)
diff --git a/superset/models/core.py b/superset/models/core.py
index 309c4444a6..77d9170c42 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -21,7 +21,7 @@ import json
import logging
import textwrap
from ast import literal_eval
-from contextlib import closing, contextmanager
+from contextlib import closing, contextmanager, nullcontext
from copy import deepcopy
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING
@@ -384,19 +384,18 @@ class Database(
):
# if ssh_tunnel is available build engine with information
url = make_url_safe(self.sqlalchemy_uri_decrypted)
- ssh_tunnel.bind_host = url.host
- ssh_tunnel.bind_port = url.port
- ssh_params = ssh_tunnel.parameters()
- with sshtunnel.open_tunnel(**ssh_params) as server:
- yield self._get_sqla_engine(
- schema=schema,
- nullpool=nullpool,
- source=source,
- ssh_tunnel_server=server,
- )
-
+ ssh_params = ssh_tunnel.parameters(bind_host=url.host, bind_port=url.port)
+ engine_context = sshtunnel.open_tunnel(**ssh_params)
else:
- yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
+ engine_context = nullcontext()
+
+ with engine_context as server_context:
+ yield self._get_sqla_engine(
+ schema=schema,
+ nullpool=nullpool,
+ source=source,
+ ssh_tunnel_server=server_context,
+ )
def _get_sqla_engine(
self,