You are viewing a plain-text version of this content. (The canonical link to it was a hyperlink that is not preserved in this plain-text copy.)
Posted to commits@superset.apache.org by mi...@apache.org on 2024/03/14 12:35:19 UTC
(superset) 03/05: fix: SSH Tunnel configuration settings (#27186)
This is an automated email from the ASF dual-hosted git repository.
michaelsmolina pushed a commit to branch 4.0
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 131c254fe786306d4f79fc59cd42dc5a6efd0d54
Author: Geido <60...@users.noreply.github.com>
AuthorDate: Mon Mar 11 16:56:54 2024 +0100
fix: SSH Tunnel configuration settings (#27186)
(cherry picked from commit 89e89de341c555a1fdbe9d3f5bccada58eb08059)
---
.../superset-ui-core/src/ui-overrides/types.ts | 18 +-
.../src/features/alerts/AlertReportModal.test.tsx | 4 +-
.../DatabaseConnectionForm/CommonParameters.tsx | 35 +---
.../DatabaseConnectionForm/EncryptedField.tsx | 2 +-
.../DatabaseConnectionForm/TableCatalog.tsx | 3 +-
.../DatabaseConnectionForm/ValidatedInputField.tsx | 2 +-
.../DatabaseModal/DatabaseConnectionForm/index.tsx | 130 ++++++-------
.../databases/DatabaseModal/SSHTunnelForm.tsx | 12 +-
.../DatabaseModal/SSHTunnelSwitch.test.tsx | 162 +++++++++++++++++
.../databases/DatabaseModal/SSHTunnelSwitch.tsx | 82 ++++++---
.../databases/DatabaseModal/index.test.tsx | 11 +-
.../src/features/databases/DatabaseModal/index.tsx | 132 ++++++++------
superset-frontend/src/features/databases/types.ts | 80 +++++++-
superset-frontend/src/views/CRUD/hooks.ts | 7 +-
superset/commands/database/create.py | 10 +-
superset/commands/database/ssh_tunnel/create.py | 11 ++
.../commands/database/ssh_tunnel/exceptions.py | 4 +
superset/commands/database/ssh_tunnel/update.py | 25 ++-
superset/commands/database/test_connection.py | 45 +++--
superset/commands/database/update.py | 79 +++++---
superset/databases/api.py | 7 +-
tests/integration_tests/databases/api_tests.py | 201 +++++++++++++++++++++
.../databases/ssh_tunnel/commands/create_test.py | 45 ++++-
.../databases/ssh_tunnel/commands/update_test.py | 35 +++-
24 files changed, 871 insertions(+), 271 deletions(-)
diff --git a/superset-frontend/packages/superset-ui-core/src/ui-overrides/types.ts b/superset-frontend/packages/superset-ui-core/src/ui-overrides/types.ts
index 45ec06e90e..60598bd4e1 100644
--- a/superset-frontend/packages/superset-ui-core/src/ui-overrides/types.ts
+++ b/superset-frontend/packages/superset-ui-core/src/ui-overrides/types.ts
@@ -44,15 +44,15 @@ interface MenuObjectChildProps {
disable?: boolean;
}
-export interface SwitchProps {
- isEditMode: boolean;
- dbFetched: any;
- disableSSHTunnelingForEngine?: boolean;
- useSSHTunneling: boolean;
- setUseSSHTunneling: React.Dispatch<React.SetStateAction<boolean>>;
- setDB: React.Dispatch<any>;
- isSSHTunneling: boolean;
-}
+// loose typing to avoid any circular dependencies
+// refer to SSHTunnelSwitch component for strict typing
+type SwitchProps = {
+ db: object;
+ changeMethods: {
+ onParametersChange: (event: any) => void;
+ };
+ clearValidationErrors: () => void;
+};
type ConfigDetailsProps = {
embeddedId: string;
diff --git a/superset-frontend/src/features/alerts/AlertReportModal.test.tsx b/superset-frontend/src/features/alerts/AlertReportModal.test.tsx
index ee9504286d..358aa27df3 100644
--- a/superset-frontend/src/features/alerts/AlertReportModal.test.tsx
+++ b/superset-frontend/src/features/alerts/AlertReportModal.test.tsx
@@ -541,8 +541,8 @@ test('defaults to day when CRON is not selected', async () => {
useRedux: true,
});
userEvent.click(screen.getByTestId('schedule-panel'));
- const days = screen.getAllByTitle(/day/i, { exact: true });
- expect(days.length).toBe(2);
+ const day = screen.getByText('day');
+ expect(day).toBeInTheDocument();
});
// Notification Method Section
diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
index 7b52eab26c..3f1f5f9625 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
@@ -17,12 +17,11 @@
* under the License.
*/
import React from 'react';
-import { isEmpty } from 'lodash';
import { SupersetTheme, t } from '@superset-ui/core';
import { AntdSwitch } from 'src/components';
import InfoTooltip from 'src/components/InfoTooltip';
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
-import { FieldPropTypes } from '.';
+import { FieldPropTypes } from '../../types';
import { toggleStyle, infoTooltip } from '../styles';
export const hostField = ({
@@ -252,35 +251,3 @@ export const forceSSLField = ({
/>
</div>
);
-
-export const SSHTunnelSwitch = ({
- isEditMode,
- changeMethods,
- clearValidationErrors,
- db,
-}: FieldPropTypes) => (
- <div css={(theme: SupersetTheme) => infoTooltip(theme)}>
- <AntdSwitch
- disabled={isEditMode && !isEmpty(db?.ssh_tunnel)}
- checked={db?.parameters?.ssh}
- onChange={changed => {
- changeMethods.onParametersChange({
- target: {
- type: 'toggle',
- name: 'ssh',
- checked: true,
- value: changed,
- },
- });
- clearValidationErrors();
- }}
- data-test="ssh-tunnel-switch"
- />
- <span css={toggleStyle}>{t('SSH Tunnel')}</span>
- <InfoTooltip
- tooltip={t('SSH Tunnel configuration parameters')}
- placement="right"
- viewBox="0 -5 24 24"
- />
- </div>
-);
diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx
index c5e268e569..009afc84ef 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/EncryptedField.tsx
@@ -22,7 +22,7 @@ import { AntdButton, AntdSelect } from 'src/components';
import InfoTooltip from 'src/components/InfoTooltip';
import FormLabel from 'src/components/Form/FormLabel';
import Icons from 'src/components/Icons';
-import { FieldPropTypes } from '.';
+import { FieldPropTypes } from '../../types';
import { infoTooltip, labelMarginBottom, CredentialInfoForm } from '../styles';
enum CredentialInfoOptions {
diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/TableCatalog.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/TableCatalog.tsx
index ed5cc94903..47a0ec1579 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/TableCatalog.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/TableCatalog.tsx
@@ -21,9 +21,8 @@ import { css, SupersetTheme, t } from '@superset-ui/core';
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
import FormLabel from 'src/components/Form/FormLabel';
import Icons from 'src/components/Icons';
-import { FieldPropTypes } from '.';
import { StyledFooterButton, StyledCatalogTable } from '../styles';
-import { CatalogObject } from '../../types';
+import { CatalogObject, FieldPropTypes } from '../../types';
export const TableCatalog = ({
required,
diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/ValidatedInputField.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/ValidatedInputField.tsx
index ec2e239ac4..d6794f9a21 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/ValidatedInputField.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/ValidatedInputField.tsx
@@ -19,7 +19,7 @@
import React from 'react';
import { t } from '@superset-ui/core';
import ValidatedInput from 'src/components/Form/LabeledErrorBoundInput';
-import { FieldPropTypes } from '.';
+import { FieldPropTypes } from '../../types';
const FIELD_TEXT_MAP = {
account: {
diff --git a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx
index e747b3c895..fc076b624f 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/DatabaseConnectionForm/index.tsx
@@ -17,7 +17,11 @@
* under the License.
*/
import React, { FormEvent } from 'react';
-import { SupersetTheme, JsonObject } from '@superset-ui/core';
+import {
+ SupersetTheme,
+ JsonObject,
+ getExtensionsRegistry,
+} from '@superset-ui/core';
import { InputProps } from 'antd/lib/input';
import { Form } from 'src/components/Form';
import {
@@ -31,13 +35,13 @@ import {
portField,
queryField,
usernameField,
- SSHTunnelSwitch,
} from './CommonParameters';
import { validatedInputField } from './ValidatedInputField';
import { EncryptedField } from './EncryptedField';
import { TableCatalog } from './TableCatalog';
import { formScrollableStyles, validatedFormStyles } from '../styles';
import { DatabaseForm, DatabaseObject } from '../../types';
+import SSHTunnelSwitch from '../SSHTunnelSwitch';
export const FormFieldOrder = [
'host',
@@ -59,34 +63,10 @@ export const FormFieldOrder = [
'ssh',
];
-export interface FieldPropTypes {
- required: boolean;
- hasTooltip?: boolean;
- tooltipText?: (value: any) => string;
- placeholder?: string;
- onParametersChange: (value: any) => string;
- onParametersUploadFileChange: (value: any) => string;
- changeMethods: { onParametersChange: (value: any) => string } & {
- onChange: (value: any) => string;
- } & {
- onQueryChange: (value: any) => string;
- } & { onParametersUploadFileChange: (value: any) => string } & {
- onAddTableCatalog: () => void;
- onRemoveTableCatalog: (idx: number) => void;
- } & {
- onExtraInputChange: (value: any) => void;
- onSSHTunnelParametersChange: (value: any) => string;
- };
- validationErrors: JsonObject | null;
- getValidation: () => void;
- clearValidationErrors: () => void;
- db?: DatabaseObject;
- field: string;
- isEditMode?: boolean;
- sslForced?: boolean;
- defaultDBName?: string;
- editNewDb?: boolean;
-}
+const extensionsRegistry = getExtensionsRegistry();
+
+const SSHTunnelSwitchComponent =
+ extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch;
const FORM_FIELD_MAP = {
host: hostField,
@@ -105,7 +85,7 @@ const FORM_FIELD_MAP = {
warehouse: validatedInputField,
role: validatedInputField,
account: validatedInputField,
- ssh: SSHTunnelSwitch,
+ ssh: SSHTunnelSwitchComponent,
};
interface DatabaseConnectionFormProps {
@@ -138,7 +118,7 @@ interface DatabaseConnectionFormProps {
}
const DatabaseConnectionForm = ({
- dbModel: { parameters },
+ dbModel,
db,
editNewDb,
getPlaceholder,
@@ -154,47 +134,51 @@ const DatabaseConnectionForm = ({
sslForced,
validationErrors,
clearValidationErrors,
-}: DatabaseConnectionFormProps) => (
- <Form>
- <div
- // @ts-ignore
- css={(theme: SupersetTheme) => [
- formScrollableStyles,
- validatedFormStyles(theme),
- ]}
- >
- {parameters &&
- FormFieldOrder.filter(
- (key: string) =>
- Object.keys(parameters.properties).includes(key) ||
- key === 'database_name',
- ).map(field =>
- FORM_FIELD_MAP[field]({
- required: parameters.required?.includes(field),
- changeMethods: {
- onParametersChange,
- onChange,
- onQueryChange,
- onParametersUploadFileChange,
- onAddTableCatalog,
- onRemoveTableCatalog,
- onExtraInputChange,
- },
- validationErrors,
- getValidation,
- clearValidationErrors,
- db,
- key: field,
- field,
- isEditMode,
- sslForced,
- editNewDb,
- placeholder: getPlaceholder ? getPlaceholder(field) : undefined,
- }),
- )}
- </div>
- </Form>
-);
+}: DatabaseConnectionFormProps) => {
+ const parameters = dbModel?.parameters;
+
+ return (
+ <Form>
+ <div
+ // @ts-ignore
+ css={(theme: SupersetTheme) => [
+ formScrollableStyles,
+ validatedFormStyles(theme),
+ ]}
+ >
+ {parameters &&
+ FormFieldOrder.filter(
+ (key: string) =>
+ Object.keys(parameters.properties).includes(key) ||
+ key === 'database_name',
+ ).map(field =>
+ FORM_FIELD_MAP[field]({
+ required: parameters.required?.includes(field),
+ changeMethods: {
+ onParametersChange,
+ onChange,
+ onQueryChange,
+ onParametersUploadFileChange,
+ onAddTableCatalog,
+ onRemoveTableCatalog,
+ onExtraInputChange,
+ },
+ validationErrors,
+ getValidation,
+ clearValidationErrors,
+ db,
+ key: field,
+ field,
+ isEditMode,
+ sslForced,
+ editNewDb,
+ placeholder: getPlaceholder ? getPlaceholder(field) : undefined,
+ }),
+ )}
+ </div>
+ </Form>
+ );
+};
export const FormFieldMap = FORM_FIELD_MAP;
export default DatabaseConnectionForm;
diff --git a/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelForm.tsx b/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelForm.tsx
index 7823d82faf..e0d1b16ff2 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelForm.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelForm.tsx
@@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
-import React, { EventHandler, ChangeEvent, useState } from 'react';
+import React, { useState } from 'react';
import { t, styled } from '@superset-ui/core';
import { AntdForm, Col, Row } from 'src/components';
import { Form, FormLabel } from 'src/components/Form';
@@ -24,7 +24,7 @@ import { Radio } from 'src/components/Radio';
import { Input, TextArea } from 'src/components/Input';
import { Input as AntdInput, Tooltip } from 'antd';
import { EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons';
-import { DatabaseObject } from '../types';
+import { DatabaseObject, FieldPropTypes } from '../types';
import { AuthType } from '.';
const StyledDiv = styled.div`
@@ -54,9 +54,7 @@ const SSHTunnelForm = ({
setSSHTunnelLoginMethod,
}: {
db: DatabaseObject | null;
- onSSHTunnelParametersChange: EventHandler<
- ChangeEvent<HTMLInputElement | HTMLTextAreaElement>
- >;
+ onSSHTunnelParametersChange: FieldPropTypes['changeMethods']['onSSHTunnelParametersChange'];
setSSHTunnelLoginMethod: (method: AuthType) => void;
}) => {
const [usePassword, setUsePassword] = useState<AuthType>(AuthType.Password);
@@ -86,9 +84,9 @@ const SSHTunnelForm = ({
</FormLabel>
<Input
name="server_port"
- type="text"
placeholder={t('22')}
- value={db?.ssh_tunnel?.server_port || ''}
+ type="number"
+ value={db?.ssh_tunnel?.server_port}
onChange={onSSHTunnelParametersChange}
data-test="ssh-tunnel-server_port-input"
/>
diff --git a/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.test.tsx b/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.test.tsx
new file mode 100644
index 0000000000..fef205acf2
--- /dev/null
+++ b/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.test.tsx
@@ -0,0 +1,162 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+import React from 'react';
+import { render, screen } from 'spec/helpers/testing-library';
+import userEvent from '@testing-library/user-event';
+import SSHTunnelSwitch from './SSHTunnelSwitch';
+import { DatabaseForm, DatabaseObject } from '../types';
+
+jest.mock('@superset-ui/core', () => ({
+ ...jest.requireActual('@superset-ui/core'),
+ isFeatureEnabled: jest.fn().mockReturnValue(true),
+}));
+
+jest.mock('src/components', () => ({
+ AntdSwitch: ({
+ checked,
+ onChange,
+ }: {
+ checked: boolean;
+ onChange: (checked: boolean) => void;
+ }) => (
+ <button
+ onClick={() => onChange(!checked)}
+ aria-checked={checked}
+ role="switch"
+ type="button"
+ >
+ {checked ? 'ON' : 'OFF'}
+ </button>
+ ),
+}));
+
+const mockChangeMethods = {
+ onParametersChange: jest.fn(),
+};
+
+const mockDbModel = {
+ engine: 'mysql',
+ engine_information: {
+ disable_ssh_tunneling: false,
+ },
+} as DatabaseForm;
+
+const defaultDb = {
+ parameters: { ssh: false },
+ ssh_tunnel: {},
+ engine: 'mysql',
+} as DatabaseObject;
+
+afterEach(() => {
+ jest.clearAllMocks();
+});
+
+test('Renders SSH Tunnel switch enabled by default and toggles its state', () => {
+ render(
+ <SSHTunnelSwitch
+ changeMethods={mockChangeMethods}
+ clearValidationErrors={jest.fn}
+ db={defaultDb}
+ dbModel={mockDbModel}
+ />,
+ );
+ const switchButton = screen.getByRole('switch');
+ expect(switchButton).toHaveTextContent('OFF');
+ userEvent.click(switchButton);
+ expect(mockChangeMethods.onParametersChange).toHaveBeenCalledWith({
+ target: { type: 'toggle', name: 'ssh', checked: true, value: true },
+ });
+ expect(switchButton).toHaveTextContent('ON');
+});
+
+test('Does not render if SSH Tunnel is disabled', () => {
+ render(
+ <SSHTunnelSwitch
+ changeMethods={mockChangeMethods}
+ clearValidationErrors={jest.fn}
+ db={defaultDb}
+ dbModel={{
+ ...mockDbModel,
+ engine_information: {
+ disable_ssh_tunneling: true,
+ supports_file_upload: false,
+ },
+ }}
+ />,
+ );
+ expect(screen.queryByRole('switch')).not.toBeInTheDocument();
+});
+
+test('Checks the switch based on db.parameters.ssh', () => {
+ const dbWithSSHTunnelEnabled = {
+ ...defaultDb,
+ parameters: { ssh: true },
+ } as DatabaseObject;
+ render(
+ <SSHTunnelSwitch
+ changeMethods={mockChangeMethods}
+ clearValidationErrors={jest.fn}
+ db={dbWithSSHTunnelEnabled}
+ dbModel={mockDbModel}
+ />,
+ );
+ expect(screen.getByRole('switch')).toHaveTextContent('ON');
+});
+
+test('Calls onParametersChange with true if SSH Tunnel info exists', () => {
+ const dbWithSSHTunnelInfo = {
+ ...defaultDb,
+ parameters: { ssh: undefined },
+ ssh_tunnel: { host: 'example.com' },
+ } as DatabaseObject;
+ render(
+ <SSHTunnelSwitch
+ changeMethods={mockChangeMethods}
+ clearValidationErrors={jest.fn}
+ db={dbWithSSHTunnelInfo}
+ dbModel={mockDbModel}
+ />,
+ );
+ expect(mockChangeMethods.onParametersChange).toHaveBeenCalledWith({
+ target: { type: 'toggle', name: 'ssh', checked: true, value: true },
+ });
+});
+
+test('Displays tooltip text on hover over the InfoTooltip', async () => {
+ const tooltipText = 'SSH Tunnel configuration parameters';
+ render(
+ <SSHTunnelSwitch
+ changeMethods={mockChangeMethods}
+ clearValidationErrors={jest.fn}
+ db={defaultDb}
+ dbModel={mockDbModel}
+ />,
+ );
+
+ const infoTooltipTrigger = screen.getByRole('img', {
+ name: 'info-solid_small',
+ });
+ expect(infoTooltipTrigger).toBeInTheDocument();
+
+ userEvent.hover(infoTooltipTrigger);
+
+ const tooltip = await screen.findByText(tooltipText);
+
+ expect(tooltip).toBeInTheDocument();
+});
diff --git a/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.tsx b/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.tsx
index 388e3c83b1..cf96864a3d 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/SSHTunnelSwitch.tsx
@@ -16,35 +16,73 @@
* specific language governing permissions and limitations
* under the License.
*/
-import React from 'react';
-import { t, SupersetTheme, SwitchProps } from '@superset-ui/core';
+import React, { useEffect, useState } from 'react';
+import {
+ t,
+ SupersetTheme,
+ isFeatureEnabled,
+ FeatureFlag,
+} from '@superset-ui/core';
import { AntdSwitch } from 'src/components';
import InfoTooltip from 'src/components/InfoTooltip';
import { isEmpty } from 'lodash';
-import { ActionType } from '.';
import { infoTooltip, toggleStyle } from './styles';
+import { SwitchProps } from '../types';
const SSHTunnelSwitch = ({
- isEditMode,
- dbFetched,
- useSSHTunneling,
- setUseSSHTunneling,
- setDB,
- isSSHTunneling,
-}: SwitchProps) =>
- isSSHTunneling ? (
+ clearValidationErrors,
+ changeMethods,
+ db,
+ dbModel,
+}: SwitchProps) => {
+ const [isChecked, setChecked] = useState(false);
+ const sshTunnelEnabled = isFeatureEnabled(FeatureFlag.SshTunneling);
+ const disableSSHTunnelingForEngine =
+ dbModel?.engine_information?.disable_ssh_tunneling || false;
+ const isSSHTunnelEnabled = sshTunnelEnabled && !disableSSHTunnelingForEngine;
+
+ const handleOnChange = (changed: boolean) => {
+ setChecked(changed);
+ changeMethods.onParametersChange({
+ target: {
+ type: 'toggle',
+ name: 'ssh',
+ checked: true,
+ value: changed,
+ },
+ });
+ clearValidationErrors();
+ };
+
+ useEffect(() => {
+ if (isSSHTunnelEnabled && db?.parameters?.ssh !== undefined) {
+ setChecked(db.parameters.ssh);
+ }
+ }, [db?.parameters?.ssh, isSSHTunnelEnabled]);
+
+ useEffect(() => {
+ if (
+ isSSHTunnelEnabled &&
+ db?.parameters?.ssh === undefined &&
+ !isEmpty(db?.ssh_tunnel)
+ ) {
+ // reflecting the state of the ssh tunnel on first load
+ changeMethods.onParametersChange({
+ target: {
+ type: 'toggle',
+ name: 'ssh',
+ checked: true,
+ value: true,
+ },
+ });
+ }
+ }, [changeMethods, db?.parameters?.ssh, db?.ssh_tunnel, isSSHTunnelEnabled]);
+
+ return isSSHTunnelEnabled ? (
<div css={(theme: SupersetTheme) => infoTooltip(theme)}>
<AntdSwitch
- disabled={isEditMode && !isEmpty(dbFetched?.ssh_tunnel)}
- checked={useSSHTunneling}
- onChange={changed => {
- setUseSSHTunneling(changed);
- if (!changed) {
- setDB({
- type: ActionType.RemoveSSHTunnelConfig,
- });
- }
- }}
+ checked={isChecked}
+ onChange={handleOnChange}
data-test="ssh-tunnel-switch"
/>
<span css={toggleStyle}>{t('SSH Tunnel')}</span>
@@ -55,4 +93,6 @@ const SSHTunnelSwitch = ({
/>
</div>
) : null;
+};
+
export default SSHTunnelSwitch;
diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx
index 0f60857f06..7e8018b25f 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx
@@ -16,6 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
+
+// TODO: These tests should be made atomic in separate files
+
import React from 'react';
import fetchMock from 'fetch-mock';
import userEvent from '@testing-library/user-event';
@@ -1227,9 +1230,9 @@ describe('DatabaseModal', () => {
const SSHTunnelServerPortInput = screen.getByTestId(
'ssh-tunnel-server_port-input',
);
- expect(SSHTunnelServerPortInput).toHaveValue('');
+ expect(SSHTunnelServerPortInput).toHaveValue(null);
userEvent.type(SSHTunnelServerPortInput, '22');
- expect(SSHTunnelServerPortInput).toHaveValue('22');
+ expect(SSHTunnelServerPortInput).toHaveValue(22);
const SSHTunnelUsernameInput = screen.getByTestId(
'ssh-tunnel-username-input',
);
@@ -1263,9 +1266,9 @@ describe('DatabaseModal', () => {
const SSHTunnelServerPortInput = screen.getByTestId(
'ssh-tunnel-server_port-input',
);
- expect(SSHTunnelServerPortInput).toHaveValue('');
+ expect(SSHTunnelServerPortInput).toHaveValue(null);
userEvent.type(SSHTunnelServerPortInput, '22');
- expect(SSHTunnelServerPortInput).toHaveValue('22');
+ expect(SSHTunnelServerPortInput).toHaveValue(22);
const SSHTunnelUsernameInput = screen.getByTestId(
'ssh-tunnel-username-input',
);
diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.tsx
index 60ae032feb..47c9a8b658 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/index.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/index.tsx
@@ -20,8 +20,6 @@ import {
t,
styled,
SupersetTheme,
- FeatureFlag,
- isFeatureEnabled,
getExtensionsRegistry,
} from '@superset-ui/core';
import React, {
@@ -31,6 +29,7 @@ import React, {
useState,
useReducer,
Reducer,
+ useCallback,
} from 'react';
import { useHistory } from 'react-router-dom';
import { setItem, LocalStorageKeys } from 'src/utils/localStorageHelpers';
@@ -65,6 +64,7 @@ import {
CatalogObject,
Engines,
ExtraJson,
+ CustomTextType,
} from '../types';
import ExtraOptions from './ExtraOptions';
import SqlAlchemyForm from './SqlAlchemyForm';
@@ -208,8 +208,8 @@ export type DBReducerActionType =
| {
type:
| ActionType.Reset
- | ActionType.AddTableCatalogSheet
- | ActionType.RemoveSSHTunnelConfig;
+ | ActionType.RemoveSSHTunnelConfig
+ | ActionType.AddTableCatalogSheet;
}
| {
type: ActionType.RemoveTableCatalogSheet;
@@ -595,7 +595,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const SSHTunnelSwitchComponent =
extensionsRegistry.get('ssh_tunnel.form.switch') ?? SSHTunnelSwitch;
- const [useSSHTunneling, setUseSSHTunneling] = useState<boolean>(false);
+ const [useSSHTunneling, setUseSSHTunneling] = useState<boolean | undefined>(
+ undefined,
+ );
let dbConfigExtraExtension = extensionsRegistry.get(
'databaseconnection.extraOption',
@@ -618,14 +620,6 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const dbImages = getDatabaseImages();
const connectionAlert = getConnectionAlert();
const isEditMode = !!databaseId;
- const disableSSHTunnelingForEngine = (
- availableDbs?.databases?.find(
- (DB: DatabaseObject) =>
- DB.backend === db?.engine || DB.engine === db?.engine,
- ) as DatabaseObject
- )?.engine_information?.disable_ssh_tunneling;
- const isSSHTunneling =
- isFeatureEnabled(FeatureFlag.SshTunneling) && !disableSSHTunnelingForEngine;
const hasAlert =
connectionAlert || !!(db?.engine && engineSpecificAlertMapping[db.engine]);
const useSqlAlchemyForm =
@@ -659,7 +653,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
extra: db?.extra,
masked_encrypted_extra: db?.masked_encrypted_extra || '',
server_cert: db?.server_cert || undefined,
- ssh_tunnel: db?.ssh_tunnel || undefined,
+ ssh_tunnel:
+ !isEmpty(db?.ssh_tunnel) && useSSHTunneling
+ ? {
+ ...db.ssh_tunnel,
+ server_port: Number(db.ssh_tunnel!.server_port),
+ }
+ : undefined,
};
setTestInProgress(true);
testDatabaseConnection(
@@ -687,10 +687,36 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
return false;
};
+ const onChange = useCallback(
+ (
+ type: DBReducerActionType['type'],
+ payload: CustomTextType | DBReducerPayloadType,
+ ) => {
+ setDB({ type, payload } as DBReducerActionType);
+ },
+ [],
+ );
+
+ const handleClearValidationErrors = useCallback(() => {
+ setValidationErrors(null);
+ }, [setValidationErrors]);
+
+ const handleParametersChange = useCallback(
+ ({ target }: { target: HTMLInputElement }) => {
+ onChange(ActionType.ParametersChange, {
+ type: target.type,
+ name: target.name,
+ checked: target.checked,
+ value: target.value,
+ });
+ },
+ [onChange],
+ );
+
const onClose = () => {
setDB({ type: ActionType.Reset });
setHasConnectedDb(false);
- setValidationErrors(null); // reset validation errors on close
+ handleClearValidationErrors(); // reset validation errors on close
clearError();
setEditNewDb(false);
setFileList([]);
@@ -705,7 +731,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
setSSHTunnelPrivateKeys({});
setSSHTunnelPrivateKeyPasswords({});
setConfirmedOverwrite(false);
- setUseSSHTunneling(false);
+ setUseSSHTunneling(undefined);
onHide();
};
@@ -729,12 +755,11 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
setImportingErrorMessage(msg);
});
- const onChange = (type: any, payload: any) => {
- setDB({ type, payload } as DBReducerActionType);
- };
-
const onSave = async () => {
let dbConfigExtraExtensionOnSaveError;
+
+ setLoading(true);
+
dbConfigExtraExtension
?.onSave(extraExtensionComponentState, db)
.then(({ error }: { error: any }) => {
@@ -743,6 +768,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
addDangerToast(error);
}
});
+
if (dbConfigExtraExtensionOnSaveError) {
setLoading(false);
return;
@@ -762,17 +788,13 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
});
}
- // only do validation for non ssh tunnel connections
- if (!dbToUpdate?.ssh_tunnel) {
- // make sure that button spinner animates
- setLoading(true);
- const errors = await getValidation(dbToUpdate, true);
- if ((validationErrors && !isEmpty(validationErrors)) || errors) {
- setLoading(false);
- return;
- }
- // end spinner animation
+ const errors = await getValidation(dbToUpdate, true);
+ if (!isEmpty(validationErrors) || errors?.length) {
+ addDangerToast(
+ t('Connection failed, please check your connection settings.'),
+ );
setLoading(false);
+ return;
}
const parameters_schema = isEditMode
@@ -829,7 +851,12 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
});
}
- setLoading(true);
+ // strictly checking for false as an indication that the toggle got unchecked
+ if (useSSHTunneling === false) {
+ // remove ssh tunnel
+ dbToUpdate.ssh_tunnel = null;
+ }
+
if (db?.id) {
const result = await updateResource(
db.id as number,
@@ -1282,10 +1309,10 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
}, [sshPrivateKeyPasswordNeeded]);
useEffect(() => {
- if (db && isSSHTunneling) {
- setUseSSHTunneling(!isEmpty(db?.ssh_tunnel));
+ if (db?.parameters?.ssh !== undefined) {
+ setUseSSHTunneling(db.parameters.ssh);
}
- }, [db, isSSHTunneling]);
+ }, [db?.parameters?.ssh]);
const onDbImport = async (info: UploadChangeParam) => {
setImportingErrorMessage('');
@@ -1550,17 +1577,14 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
const renderSSHTunnelForm = () => (
<SSHTunnelForm
db={db as DatabaseObject}
- onSSHTunnelParametersChange={({
- target,
- }: {
- target: HTMLInputElement | HTMLTextAreaElement;
- }) =>
+ onSSHTunnelParametersChange={({ target }) => {
onChange(ActionType.ParametersSSHTunnelChange, {
type: target.type,
name: target.name,
value: target.value,
- })
- }
+ });
+ handleClearValidationErrors();
+ }}
setSSHTunnelLoginMethod={(method: AuthType) =>
setDB({
type: ActionType.SetSSHTunnelLoginMethod,
@@ -1623,14 +1647,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
payload: { indexToDelete: idx },
});
}}
- onParametersChange={({ target }: { target: HTMLInputElement }) =>
- onChange(ActionType.ParametersChange, {
- type: target.type,
- name: target.name,
- checked: target.checked,
- value: target.value,
- })
- }
+ onParametersChange={handleParametersChange}
onChange={({ target }: { target: HTMLInputElement }) =>
onChange(ActionType.TextChange, {
name: target.name,
@@ -1640,9 +1657,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
getValidation={() => getValidation(db)}
validationErrors={validationErrors}
getPlaceholder={getPlaceholder}
- clearValidationErrors={() => setValidationErrors(null)}
+ clearValidationErrors={handleClearValidationErrors}
/>
- {db?.parameters?.ssh && (
+ {useSSHTunneling && (
<SSHTunnelContainer>{renderSSHTunnelForm()}</SSHTunnelContainer>
)}
</>
@@ -1792,13 +1809,12 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
testInProgress={testInProgress}
>
<SSHTunnelSwitchComponent
- isEditMode={isEditMode}
- dbFetched={dbFetched}
- disableSSHTunnelingForEngine={disableSSHTunnelingForEngine}
- useSSHTunneling={useSSHTunneling}
- setUseSSHTunneling={setUseSSHTunneling}
- setDB={setDB}
- isSSHTunneling={isSSHTunneling}
+ dbModel={dbModel}
+ db={db as DatabaseObject}
+ changeMethods={{
+ onParametersChange: handleParametersChange,
+ }}
+ clearValidationErrors={handleClearValidationErrors}
/>
{useSSHTunneling && renderSSHTunnelForm()}
</SqlAlchemyForm>
diff --git a/superset-frontend/src/features/databases/types.ts b/superset-frontend/src/features/databases/types.ts
index 50e535f9b1..58d533c7be 100644
--- a/superset-frontend/src/features/databases/types.ts
+++ b/superset-frontend/src/features/databases/types.ts
@@ -1,3 +1,7 @@
+import { JsonObject } from '@superset-ui/core';
+import { InputProps } from 'antd/lib/input';
+import { ChangeEvent, EventHandler, FormEvent } from 'react';
+
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
@@ -108,7 +112,7 @@ export type DatabaseObject = {
};
// SSH Tunnel information
- ssh_tunnel?: SSHTunnelObject;
+ ssh_tunnel?: SSHTunnelObject | null;
};
export type DatabaseForm = {
@@ -195,6 +199,10 @@ export type DatabaseForm = {
};
preferred: boolean;
sqlalchemy_uri_placeholder: string;
+ engine_information: {
+ supports_file_upload: boolean;
+ disable_ssh_tunneling: boolean;
+ };
};
// the values should align with the database
@@ -231,3 +239,73 @@ export interface ExtraJson {
};
version?: string;
}
+
+export type CustomTextType = {
+ value?: string | boolean | number;
+ type?: string | null;
+ name?: string;
+ checked?: boolean;
+};
+
+type CustomHTMLInputElement = Omit<Partial<CustomTextType>, 'value' | 'type'> &
+ CustomTextType;
+
+type CustomHTMLTextAreaElement = Omit<
+ Partial<CustomTextType>,
+ 'value' | 'type'
+> &
+ CustomTextType;
+
+export type CustomParametersChangeType<T = CustomTextType> =
+ | FormEvent<InputProps>
+ | { target: T };
+
+export type CustomEventHandlerType = EventHandler<
+ ChangeEvent<CustomHTMLInputElement | CustomHTMLTextAreaElement>
+>;
+
+export interface FieldPropTypes {
+ required: boolean;
+ hasTooltip?: boolean;
+ tooltipText?: (value: any) => string;
+ placeholder?: string;
+ onParametersChange: (event: CustomParametersChangeType) => void;
+ onParametersUploadFileChange: (value: any) => string;
+ changeMethods: {
+ onParametersChange: (event: CustomParametersChangeType) => void;
+ } & {
+ onChange: (value: any) => string;
+ } & {
+ onQueryChange: (value: any) => string;
+ } & { onParametersUploadFileChange: (value: any) => string } & {
+ onAddTableCatalog: () => void;
+ onRemoveTableCatalog: (idx: number) => void;
+ } & {
+ onExtraInputChange: (value: any) => void;
+ onSSHTunnelParametersChange: CustomEventHandlerType;
+ };
+ validationErrors: JsonObject | null;
+ getValidation: () => void;
+ clearValidationErrors: () => void;
+ db?: DatabaseObject;
+ dbModel?: DatabaseForm;
+ field: string;
+ isEditMode?: boolean;
+ sslForced?: boolean;
+ defaultDBName?: string;
+ editNewDb?: boolean;
+}
+
+type ChangeMethodsType = FieldPropTypes['changeMethods'];
+
+// changeMethods compatibility with dynamic forms
+type SwitchPropsChangeMethodsType = {
+ onParametersChange: ChangeMethodsType['onParametersChange'];
+};
+
+export type SwitchProps = {
+ dbModel: DatabaseForm;
+ db: DatabaseObject;
+ changeMethods: SwitchPropsChangeMethodsType;
+ clearValidationErrors: () => void;
+};
diff --git a/superset-frontend/src/views/CRUD/hooks.ts b/superset-frontend/src/views/CRUD/hooks.ts
index 85f7c60252..8f31f2fcdd 100644
--- a/superset-frontend/src/views/CRUD/hooks.ts
+++ b/superset-frontend/src/views/CRUD/hooks.ts
@@ -35,7 +35,8 @@ import Chart, { Slice } from 'src/types/Chart';
import copyTextToClipboard from 'src/utils/copy';
import { getClientErrorObject } from 'src/utils/getClientErrorObject';
import SupersetText from 'src/utils/textUtils';
-import { FavoriteStatus, ImportResourceName, DatabaseObject } from './types';
+import { DatabaseObject } from 'src/features/databases/types';
+import { FavoriteStatus, ImportResourceName } from './types';
interface ListViewResourceState<D extends object = any> {
loading: boolean;
@@ -691,7 +692,7 @@ export const getDatabaseDocumentationLinks = () =>
SupersetText.DB_CONNECTION_DOC_LINKS;
export const testDatabaseConnection = (
- connection: DatabaseObject,
+ connection: Partial<DatabaseObject>,
handleErrorMsg: (errorMsg: string) => void,
addSuccessToast: (arg0: string) => void,
) => {
@@ -745,7 +746,7 @@ export function useDatabaseValidation() {
const getValidation = useCallback(
(database: Partial<DatabaseObject> | null, onCreate = false) => {
if (database?.parameters?.ssh) {
- // when ssh tunnel is enabled we don't want to render any validation errors
+ // TODO: /validate_parameters/ and related utils should support ssh tunnel
setValidationErrors(null);
return [];
}
diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py
index cde9dd8e88..9efb39b75a 100644
--- a/superset/commands/database/create.py
+++ b/superset/commands/database/create.py
@@ -19,6 +19,7 @@ from typing import Any, Optional
from flask import current_app
from flask_appbuilder.models.sqla import Model
+from flask_babel import gettext as _
from marshmallow import ValidationError
from superset import is_feature_enabled
@@ -33,6 +34,7 @@ from superset.commands.database.exceptions import (
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelCreateFailedError,
+ SSHTunnelDatabasePortError,
SSHTunnelingNotEnabledError,
SSHTunnelInvalidError,
)
@@ -57,7 +59,11 @@ class CreateDatabaseCommand(BaseCommand):
try:
# Test connection before starting create transaction
TestConnectionDatabaseCommand(self._properties).run()
- except (SupersetErrorsException, SSHTunnelingNotEnabledError) as ex:
+ except (
+ SupersetErrorsException,
+ SSHTunnelingNotEnabledError,
+ SSHTunnelDatabasePortError,
+ ) as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
@@ -103,6 +109,7 @@ class CreateDatabaseCommand(BaseCommand):
SSHTunnelInvalidError,
SSHTunnelCreateFailedError,
SSHTunnelingNotEnabledError,
+ SSHTunnelDatabasePortError,
) as ex:
db.session.rollback()
event_logger.log_with_context(
@@ -140,6 +147,7 @@ class CreateDatabaseCommand(BaseCommand):
# Check database_name uniqueness
if not DatabaseDAO.validate_uniqueness(database_name):
exceptions.append(DatabaseExistsValidationError())
+
if exceptions:
exception = DatabaseInvalidError()
exception.extend(exceptions)
diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py
index cbfee3ce2a..287accc5aa 100644
--- a/superset/commands/database/ssh_tunnel/create.py
+++ b/superset/commands/database/ssh_tunnel/create.py
@@ -23,11 +23,13 @@ from marshmallow import ValidationError
from superset.commands.base import BaseCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelCreateFailedError,
+ SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
SSHTunnelRequiredFieldValidationError,
)
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOCreateFailedError
+from superset.databases.utils import make_url_safe
from superset.extensions import event_logger
from superset.models.core import Database
@@ -35,9 +37,12 @@ logger = logging.getLogger(__name__)
class CreateSSHTunnelCommand(BaseCommand):
+ _database: Database
+
def __init__(self, database: Database, data: dict[str, Any]):
self._properties = data.copy()
self._properties["database"] = database
+ self._database = database
def run(self) -> Model:
try:
@@ -57,16 +62,22 @@ class CreateSSHTunnelCommand(BaseCommand):
server_address: Optional[str] = self._properties.get("server_address")
server_port: Optional[int] = self._properties.get("server_port")
username: Optional[str] = self._properties.get("username")
+ password: Optional[str] = self._properties.get("password")
private_key: Optional[str] = self._properties.get("private_key")
private_key_password: Optional[str] = self._properties.get(
"private_key_password"
)
+ url = make_url_safe(self._database.sqlalchemy_uri)
+ if not url.port:
+ raise SSHTunnelDatabasePortError()
if not server_address:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
if not server_port:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_port"))
if not username:
exceptions.append(SSHTunnelRequiredFieldValidationError("username"))
+ if not private_key and not password:
+ exceptions.append(SSHTunnelRequiredFieldValidationError("password"))
if private_key_password and private_key is None:
exceptions.append(SSHTunnelRequiredFieldValidationError("private_key"))
if exceptions:
diff --git a/superset/commands/database/ssh_tunnel/exceptions.py b/superset/commands/database/ssh_tunnel/exceptions.py
index 0e3f91cae6..a0def8c087 100644
--- a/superset/commands/database/ssh_tunnel/exceptions.py
+++ b/superset/commands/database/ssh_tunnel/exceptions.py
@@ -38,6 +38,10 @@ class SSHTunnelInvalidError(CommandInvalidError):
message = _("SSH Tunnel parameters are invalid.")
+class SSHTunnelDatabasePortError(CommandInvalidError):
+ message = _("A database port is required when connecting via SSH Tunnel.")
+
+
class SSHTunnelUpdateFailedError(UpdateFailedError):
message = _("SSH Tunnel could not be updated.")
diff --git a/superset/commands/database/ssh_tunnel/update.py b/superset/commands/database/ssh_tunnel/update.py
index ae7ee78afe..d0dd14a5b2 100644
--- a/superset/commands/database/ssh_tunnel/update.py
+++ b/superset/commands/database/ssh_tunnel/update.py
@@ -21,6 +21,7 @@ from flask_appbuilder.models.sqla import Model
from superset.commands.base import BaseCommand
from superset.commands.database.ssh_tunnel.exceptions import (
+ SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
SSHTunnelNotFoundError,
SSHTunnelRequiredFieldValidationError,
@@ -29,6 +30,7 @@ from superset.commands.database.ssh_tunnel.exceptions import (
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
+from superset.databases.utils import make_url_safe
logger = logging.getLogger(__name__)
@@ -39,20 +41,33 @@ class UpdateSSHTunnelCommand(BaseCommand):
self._model_id = model_id
self._model: Optional[SSHTunnel] = None
- def run(self) -> Model:
+ def run(self) -> Optional[Model]:
self.validate()
try:
- if self._model is not None: # So we dont get incompatible types error
- tunnel = SSHTunnelDAO.update(self._model, self._properties)
+ if self._model is None:
+ return None
+
+ # unset password if private key is provided
+ if self._properties.get("private_key"):
+ self._properties["password"] = None
+
+ # unset private key and password if password is provided
+ if self._properties.get("password"):
+ self._properties["private_key"] = None
+ self._properties["private_key_password"] = None
+
+ tunnel = SSHTunnelDAO.update(self._model, self._properties)
+ return tunnel
except DAOUpdateFailedError as ex:
raise SSHTunnelUpdateFailedError() from ex
- return tunnel
def validate(self) -> None:
# Validate/populate model exists
self._model = SSHTunnelDAO.find_by_id(self._model_id)
if not self._model:
raise SSHTunnelNotFoundError()
+
+ url = make_url_safe(self._model.database.sqlalchemy_uri)
private_key: Optional[str] = self._properties.get("private_key")
private_key_password: Optional[str] = self._properties.get(
"private_key_password"
@@ -61,3 +76,5 @@ class UpdateSSHTunnelCommand(BaseCommand):
raise SSHTunnelInvalidError(
exceptions=[SSHTunnelRequiredFieldValidationError("private_key")]
)
+ if not url.port:
+ raise SSHTunnelDatabasePortError()
diff --git a/superset/commands/database/test_connection.py b/superset/commands/database/test_connection.py
index 0ffdf3ddd9..431918c6bc 100644
--- a/superset/commands/database/test_connection.py
+++ b/superset/commands/database/test_connection.py
@@ -32,7 +32,10 @@ from superset.commands.database.exceptions import (
DatabaseTestConnectionDriverError,
DatabaseTestConnectionUnexpectedError,
)
-from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelingNotEnabledError
+from superset.commands.database.ssh_tunnel.exceptions import (
+ SSHTunnelDatabasePortError,
+ SSHTunnelingNotEnabledError,
+)
from superset.daos.database import DatabaseDAO, SSHTunnelDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
@@ -61,20 +64,22 @@ def get_log_connection_action(
class TestConnectionDatabaseCommand(BaseCommand):
+ _model: Optional[Database] = None
+ _context: dict[str, Any]
+ _uri: str
+
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
- self._model: Optional[Database] = None
- def run(self) -> None: # pylint: disable=too-many-statements, too-many-branches
- self.validate()
- ex_str = ""
+ if (database_name := self._properties.get("database_name")) is not None:
+ self._model = DatabaseDAO.get_database_by_name(database_name)
+
uri = self._properties.get("sqlalchemy_uri", "")
if self._model and uri == self._model.safe_sqlalchemy_uri():
uri = self._model.sqlalchemy_uri_decrypted
- ssh_tunnel = self._properties.get("ssh_tunnel")
- # context for error messages
url = make_url_safe(uri)
+
context = {
"hostname": url.host,
"password": url.password,
@@ -83,6 +88,14 @@ class TestConnectionDatabaseCommand(BaseCommand):
"database": url.database,
}
+ self._context = context
+ self._uri = uri
+
+ def run(self) -> None: # pylint: disable=too-many-statements
+ self.validate()
+ ex_str = ""
+ ssh_tunnel = self._properties.get("ssh_tunnel")
+
serialized_encrypted_extra = self._properties.get(
"masked_encrypted_extra",
"{}",
@@ -103,15 +116,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
encrypted_extra=serialized_encrypted_extra,
)
- database.set_sqlalchemy_uri(uri)
+ database.set_sqlalchemy_uri(self._uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
# Generate tunnel if present in the properties
if ssh_tunnel:
- if not is_feature_enabled("SSH_TUNNELING"):
- raise SSHTunnelingNotEnabledError()
- # If there's an existing tunnel for that DB we need to use the stored
- # password, private_key and private_key_password instead
+ # unmask password while allowing for updated values
if ssh_tunnel_id := ssh_tunnel.pop("id", None):
if existing_ssh_tunnel := SSHTunnelDAO.find_by_id(ssh_tunnel_id):
ssh_tunnel = unmask_password_info(
@@ -186,7 +196,7 @@ class TestConnectionDatabaseCommand(BaseCommand):
engine=database.db_engine_spec.__name__,
)
# check for custom errors (wrong username, wrong password, etc)
- errors = database.db_engine_spec.extract_errors(ex, context)
+ errors = database.db_engine_spec.extract_errors(ex, self._context)
raise SupersetErrorsException(errors) from ex
except SupersetSecurityException as ex:
event_logger.log_with_context(
@@ -221,9 +231,12 @@ class TestConnectionDatabaseCommand(BaseCommand):
),
engine=database.db_engine_spec.__name__,
)
- errors = database.db_engine_spec.extract_errors(ex, context)
+ errors = database.db_engine_spec.extract_errors(ex, self._context)
raise DatabaseTestConnectionUnexpectedError(errors) from ex
def validate(self) -> None:
- if (database_name := self._properties.get("database_name")) is not None:
- self._model = DatabaseDAO.get_database_by_name(database_name)
+ if self._properties.get("ssh_tunnel"):
+ if not is_feature_enabled("SSH_TUNNELING"):
+ raise SSHTunnelingNotEnabledError()
+ if not self._context.get("port"):
+ raise SSHTunnelDatabasePortError()
diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py
index edc0ba1b98..5575d674a8 100644
--- a/superset/commands/database/update.py
+++ b/superset/commands/database/update.py
@@ -18,6 +18,7 @@ import logging
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
+from flask_babel import gettext as _
from marshmallow import ValidationError
from superset import is_feature_enabled
@@ -30,8 +31,11 @@ from superset.commands.database.exceptions import (
DatabaseUpdateFailedError,
)
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
+from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelCreateFailedError,
+ SSHTunnelDatabasePortError,
+ SSHTunnelDeleteFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelInvalidError,
SSHTunnelUpdateFailedError,
@@ -47,15 +51,21 @@ logger = logging.getLogger(__name__)
class UpdateDatabaseCommand(BaseCommand):
+ _model: Optional[Database]
+
def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[Database] = None
- def run(self) -> Model:
- self.validate()
+ def run(self) -> Model: # pylint: disable=too-many-statements, too-many-branches
+ self._model = DatabaseDAO.find_by_id(self._model_id)
+
if not self._model:
raise DatabaseNotFoundError()
+
+ self.validate()
+
old_database_name = self._model.database_name
# unmask ``encrypted_extra``
@@ -70,36 +80,59 @@ class UpdateDatabaseCommand(BaseCommand):
database = DatabaseDAO.update(self._model, self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
- if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
+ ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
+
+ if "ssh_tunnel" in self._properties:
if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError()
- existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
- if existing_ssh_tunnel_model is None:
- # We couldn't found an existing tunnel so we need to create one
- try:
- CreateSSHTunnelCommand(database, ssh_tunnel_properties).run()
- except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
- # So we can show the original message
- raise ex
- except Exception as ex:
- raise DatabaseUpdateFailedError() from ex
- else:
- # We found an existing tunnel so we need to update it
+
+ if self._properties.get("ssh_tunnel") is None and ssh_tunnel:
+ # We need to remove the existing tunnel
try:
- UpdateSSHTunnelCommand(
- existing_ssh_tunnel_model.id, ssh_tunnel_properties
- ).run()
- except (SSHTunnelInvalidError, SSHTunnelUpdateFailedError) as ex:
- # So we can show the original message
+ DeleteSSHTunnelCommand(ssh_tunnel.id).run()
+ ssh_tunnel = None
+ except SSHTunnelDeleteFailedError as ex:
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex
+ if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
+ if ssh_tunnel is None:
+ # We couldn't found an existing tunnel so we need to create one
+ try:
+ ssh_tunnel = CreateSSHTunnelCommand(
+ database, ssh_tunnel_properties
+ ).run()
+ except (
+ SSHTunnelInvalidError,
+ SSHTunnelCreateFailedError,
+ SSHTunnelDatabasePortError,
+ ) as ex:
+ # So we can show the original message
+ raise ex
+ except Exception as ex:
+ raise DatabaseUpdateFailedError() from ex
+ else:
+ # We found an existing tunnel so we need to update it
+ try:
+ ssh_tunnel_id = ssh_tunnel.id
+ ssh_tunnel = UpdateSSHTunnelCommand(
+ ssh_tunnel_id, ssh_tunnel_properties
+ ).run()
+ except (
+ SSHTunnelInvalidError,
+ SSHTunnelUpdateFailedError,
+ SSHTunnelDatabasePortError,
+ ) as ex:
+ # So we can show the original message
+ raise ex
+ except Exception as ex:
+ raise DatabaseUpdateFailedError() from ex
+
# adding a new database we always want to force refresh schema list
# TODO Improve this simplistic implementation for catching DB conn fails
try:
- ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
except Exception as ex:
db.session.rollback()
@@ -167,10 +200,6 @@ class UpdateDatabaseCommand(BaseCommand):
def validate(self) -> None:
exceptions: list[ValidationError] = []
- # Validate/populate model exists
- self._model = DatabaseDAO.find_by_id(self._model_id)
- if not self._model:
- raise DatabaseNotFoundError()
database_name: Optional[str] = self._properties.get("database_name")
if database_name:
# Check database_name uniqueness
diff --git a/superset/databases/api.py b/superset/databases/api.py
index 2f95bd0442..4d7d4c531a 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -47,6 +47,7 @@ from superset.commands.database.export import ExportDatabasesCommand
from superset.commands.database.importers.dispatcher import ImportDatabasesCommand
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import (
+ SSHTunnelDatabasePortError,
SSHTunnelDeleteFailedError,
SSHTunnelingNotEnabledError,
)
@@ -415,7 +416,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
exc_info=True,
)
return self.response_422(message=str(ex))
- except SSHTunnelingNotEnabledError as ex:
+ except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
return self.response_400(message=str(ex))
except SupersetException as ex:
return self.response(ex.status, message=ex.message)
@@ -500,7 +501,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
exc_info=True,
)
return self.response_422(message=str(ex))
- except SSHTunnelingNotEnabledError as ex:
+ except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
return self.response_400(message=str(ex))
@expose("/<int:pk>", methods=("DELETE",))
@@ -918,7 +919,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
try:
TestConnectionDatabaseCommand(item).run()
return self.response(200, message="OK")
- except SSHTunnelingNotEnabledError as ex:
+ except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
return self.response_400(message=str(ex))
@expose("/<int:pk>/related_objects/", methods=("GET",))
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index ebabc16e87..0f9dc03723 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -35,6 +35,7 @@ from sqlalchemy.exc import DBAPIError
from sqlalchemy.sql import func
from superset import db, security_manager
+from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
@@ -336,6 +337,58 @@ class TestDatabaseApi(SupersetTestCase):
db.session.delete(model)
db.session.commit()
+ @mock.patch(
+ "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
+ )
+ @mock.patch("superset.commands.database.create.is_feature_enabled")
+ @mock.patch(
+ "superset.models.core.Database.get_all_schema_names",
+ )
+ def test_create_database_with_missing_port_raises_error(
+ self,
+ mock_test_connection_database_command_run,
+ mock_create_is_feature_enabled,
+ mock_get_all_schema_names,
+ ):
+ """
+ Database API: Test that missing port raises SSHTunnelDatabaseError
+ """
+ mock_create_is_feature_enabled.return_value = True
+ self.login(username="admin")
+ example_db = get_example_database()
+ if example_db.backend == "sqlite":
+ return
+
+ modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
+
+ ssh_tunnel_properties = {
+ "server_address": "123.132.123.1",
+ "server_port": 8080,
+ "username": "foo",
+ "password": "bar",
+ }
+
+ database_data_with_ssh_tunnel = {
+ "database_name": "test-db-with-ssh-tunnel",
+ "sqlalchemy_uri": modified_sqlalchemy_uri,
+ "ssh_tunnel": ssh_tunnel_properties,
+ }
+
+ database_data_with_ssh_tunnel = {
+ "database_name": "test-db-with-ssh-tunnel",
+ "sqlalchemy_uri": modified_sqlalchemy_uri,
+ "ssh_tunnel": ssh_tunnel_properties,
+ }
+
+ uri = "api/v1/database/"
+ rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
+ response = json.loads(rv.data.decode("utf-8"))
+ self.assertEqual(rv.status_code, 400)
+ self.assertEqual(
+ response.get("message"),
+ "A database port is required when connecting via SSH Tunnel.",
+ )
+
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@@ -397,6 +450,154 @@ class TestDatabaseApi(SupersetTestCase):
db.session.delete(model)
db.session.commit()
+ @mock.patch(
+ "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
+ )
+ @mock.patch("superset.commands.database.create.is_feature_enabled")
+ @mock.patch("superset.commands.database.update.is_feature_enabled")
+ @mock.patch(
+ "superset.models.core.Database.get_all_schema_names",
+ )
+ def test_update_database_with_missing_port_raises_error(
+ self,
+ mock_test_connection_database_command_run,
+ mock_create_is_feature_enabled,
+ mock_update_is_feature_enabled,
+ mock_get_all_schema_names,
+ ):
+ """
+ Database API: Test that missing port raises SSHTunnelDatabaseError
+ """
+ mock_create_is_feature_enabled.return_value = True
+ mock_update_is_feature_enabled.return_value = True
+ self.login(username="admin")
+ example_db = get_example_database()
+ if example_db.backend == "sqlite":
+ return
+
+ modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
+
+ ssh_tunnel_properties = {
+ "server_address": "123.132.123.1",
+ "server_port": 8080,
+ "username": "foo",
+ "password": "bar",
+ }
+
+ database_data_with_ssh_tunnel = {
+ "database_name": "test-db-with-ssh-tunnel",
+ "sqlalchemy_uri": modified_sqlalchemy_uri,
+ "ssh_tunnel": ssh_tunnel_properties,
+ }
+
+ database_data = {
+ "database_name": "test-db-with-ssh-tunnel",
+ "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
+ }
+
+ uri = "api/v1/database/"
+ rv = self.client.post(uri, json=database_data)
+ response_create = json.loads(rv.data.decode("utf-8"))
+ self.assertEqual(rv.status_code, 201)
+
+ uri = "api/v1/database/{}".format(response_create.get("id"))
+ rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
+ response = json.loads(rv.data.decode("utf-8"))
+ self.assertEqual(rv.status_code, 400)
+ self.assertEqual(
+ response.get("message"),
+ "A database port is required when connecting via SSH Tunnel.",
+ )
+
+ # Cleanup
+ model = db.session.query(Database).get(response_create.get("id"))
+ db.session.delete(model)
+ db.session.commit()
+
+ @mock.patch(
+ "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
+ )
+ @mock.patch("superset.commands.database.create.is_feature_enabled")
+ @mock.patch("superset.commands.database.update.is_feature_enabled")
+ @mock.patch("superset.commands.database.ssh_tunnel.delete.is_feature_enabled")
+ @mock.patch(
+ "superset.models.core.Database.get_all_schema_names",
+ )
+ def test_delete_ssh_tunnel(
+ self,
+ mock_test_connection_database_command_run,
+ mock_create_is_feature_enabled,
+ mock_update_is_feature_enabled,
+ mock_delete_is_feature_enabled,
+ mock_get_all_schema_names,
+ ):
+ """
+ Database API: Test deleting a SSH tunnel via Database update
+ """
+ mock_create_is_feature_enabled.return_value = True
+ mock_update_is_feature_enabled.return_value = True
+ mock_delete_is_feature_enabled.return_value = True
+ self.login(username="admin")
+ example_db = get_example_database()
+ if example_db.backend == "sqlite":
+ return
+
+ ssh_tunnel_properties = {
+ "server_address": "123.132.123.1",
+ "server_port": 8080,
+ "username": "foo",
+ "password": "bar",
+ }
+ database_data = {
+ "database_name": "test-db-with-ssh-tunnel",
+ "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
+ }
+ database_data_with_ssh_tunnel = {
+ "database_name": "test-db-with-ssh-tunnel",
+ "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
+ "ssh_tunnel": ssh_tunnel_properties,
+ }
+
+ uri = "api/v1/database/"
+ rv = self.client.post(uri, json=database_data)
+ response = json.loads(rv.data.decode("utf-8"))
+ self.assertEqual(rv.status_code, 201)
+
+ uri = "api/v1/database/{}".format(response.get("id"))
+ rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
+ response_update = json.loads(rv.data.decode("utf-8"))
+ self.assertEqual(rv.status_code, 200)
+
+ model_ssh_tunnel = (
+ db.session.query(SSHTunnel)
+ .filter(SSHTunnel.database_id == response_update.get("id"))
+ .one()
+ )
+ self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
+
+ database_data_with_ssh_tunnel_null = {
+ "database_name": "test-db-with-ssh-tunnel",
+ "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
+ "ssh_tunnel": None,
+ }
+
+ rv = self.client.put(uri, json=database_data_with_ssh_tunnel_null)
+ response_update = json.loads(rv.data.decode("utf-8"))
+ self.assertEqual(rv.status_code, 200)
+
+ model_ssh_tunnel = (
+ db.session.query(SSHTunnel)
+ .filter(SSHTunnel.database_id == response_update.get("id"))
+ .one_or_none()
+ )
+
+ assert model_ssh_tunnel is None
+
+ # Cleanup
+ model = db.session.query(Database).get(response.get("id"))
+ db.session.delete(model)
+ db.session.commit()
+
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
index 4b05cce637..c80b52931d 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
@@ -19,7 +19,10 @@
import pytest
from sqlalchemy.orm.session import Session
-from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError
+from superset.commands.database.ssh_tunnel.exceptions import (
+ SSHTunnelDatabasePortError,
+ SSHTunnelInvalidError,
+)
def test_create_ssh_tunnel_command() -> None:
@@ -27,7 +30,11 @@ def test_create_ssh_tunnel_command() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
- database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(
+ id=1,
+ database_name="my_database",
+ sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
+ )
properties = {
"database_id": database.id,
@@ -48,7 +55,11 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
- database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
+ database = Database(
+ id=1,
+ database_name="my_database",
+ sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
+ )
# If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory
@@ -65,3 +76,31 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
+
+
+def test_create_ssh_tunnel_command_no_port() -> None:
+ from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
+ from superset.databases.ssh_tunnel.models import SSHTunnel
+ from superset.models.core import Database
+
+ database = Database(
+ id=1,
+ database_name="my_database",
+ sqlalchemy_uri="postgresql://u:p@localhost/db",
+ )
+
+ properties = {
+ "database": database,
+ "server_address": "123.132.123.1",
+ "server_port": "3005",
+ "username": "foo",
+ "password": "bar",
+ }
+
+ command = CreateSSHTunnelCommand(database, properties)
+
+ with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
+ command.run()
+ assert str(excinfo.value) == (
+ "A database port is required when connecting via SSH Tunnel."
+ )
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
index 54e54d05da..66684eb8de 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
@@ -20,11 +20,14 @@ from collections.abc import Iterator
import pytest
from sqlalchemy.orm.session import Session
-from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError
+from superset.commands.database.ssh_tunnel.exceptions import (
+ SSHTunnelDatabasePortError,
+ SSHTunnelInvalidError,
+)
@pytest.fixture
-def session_with_data(session: Session) -> Iterator[Session]:
+def session_with_data(request, session: Session) -> Iterator[Session]:
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
@@ -32,7 +35,8 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
- database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+ sqlalchemy_uri = getattr(request, "param", "postgresql://u:p@localhost:5432/db")
+ database = Database(database_name="my_database", sqlalchemy_uri=sqlalchemy_uri)
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
@@ -93,3 +97,28 @@ def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
+
+
+@pytest.mark.parametrize(
+ "session_with_data", ["postgresql://u:p@localhost/testdb"], indirect=True
+)
+def test_update_shh_tunnel_no_port(session_with_data: Session) -> None:
+ from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
+ from superset.daos.database import DatabaseDAO
+ from superset.databases.ssh_tunnel.models import SSHTunnel
+
+ result = DatabaseDAO.get_ssh_tunnel(1)
+
+ assert result
+ assert isinstance(result, SSHTunnel)
+ assert 1 == result.database_id
+ assert "Test" == result.server_address
+
+ update_payload = {"server_address": "Test update"}
+ command = UpdateSSHTunnelCommand(1, update_payload)
+
+ with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
+ command.run()
+ assert str(excinfo.value) == (
+ "A database port is required when connecting via SSH Tunnel."
+ )