You are viewing a plain-text version of this message; the canonical (linked) version is available in the mailing-list archive.
Posted to commits@superset.apache.org by el...@apache.org on 2022/12/02 18:36:34 UTC
[superset] branch master updated: feat: add databricks form (#21573)
This is an automated email from the ASF dual-hosted git repository.
elizabeth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 5c77f1ad2a feat: add databricks form (#21573)
5c77f1ad2a is described below
commit 5c77f1ad2a317254e476c718b74de639c9fc943a
Author: Elizabeth Thompson <es...@gmail.com>
AuthorDate: Fri Dec 2 10:36:27 2022 -0800
feat: add databricks form (#21573)
---
.../DatabaseConnectionForm/CommonParameters.tsx | 46 +++-
.../DatabaseModal/DatabaseConnectionForm/index.tsx | 130 ++++++-----
.../data/database/DatabaseModal/SqlAlchemyForm.tsx | 2 +-
.../data/database/DatabaseModal/index.test.tsx | 35 ++-
.../CRUD/data/database/DatabaseModal/index.tsx | 78 ++++++-
.../src/views/CRUD/data/database/types.ts | 3 +
superset/db_engine_specs/base.py | 2 +-
superset/db_engine_specs/databricks.py | 256 ++++++++++++++++++++-
.../db_engine_specs/databricks_tests.py | 67 ++++++
.../unit_tests/db_engine_specs/test_databricks.py | 135 +++++++++++
10 files changed, 685 insertions(+), 69 deletions(-)
diff --git a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
index a608f1468a..d426cf4cdf 100644
--- a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
+++ b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/DatabaseConnectionForm/CommonParameters.tsx
@@ -71,6 +71,29 @@ export const portField = ({
/>
</>
);
+export const httpPath = ({
+ required,
+ changeMethods,
+ getValidation,
+ validationErrors,
+ db,
+}: FieldPropTypes) => {
+ const extraJson = JSON.parse(db?.extra || '{}'); // http_path lives in extra.engine_params.connect_args, not in db.parameters
+ return (
+ <ValidatedInput
+ id="http_path"
+ name="http_path"
+ required={required}
+ value={extraJson.engine_params?.connect_args?.http_path}
+ validationMethods={{ onBlur: getValidation }}
+ errorMessage={validationErrors?.http_path}
+ placeholder={t('e.g. sql/protocolv1/o/12345')}
+ label={t('HTTP Path')} // localized like every other field label
+ onChange={changeMethods.onExtraInputChange}
+ helpText={t('Copy the name of the HTTP Path of your cluster.')}
+ />
+ );
+};
export const databaseField = ({
required,
changeMethods,
@@ -132,6 +155,27 @@ export const passwordField = ({
onChange={changeMethods.onParametersChange}
/>
);
+export const accessTokenField = ({
+ db,
+ isEditMode,
+ required,
+ changeMethods,
+ getValidation,
+ validationErrors,
+}: FieldPropTypes) => (
+ <ValidatedInput
+ id="access_token"
+ name="access_token"
+ label={t('Access token')}
+ placeholder={t('e.g. ********')}
+ required={required}
+ value={db?.parameters?.access_token}
+ onChange={changeMethods.onParametersChange}
+ validationMethods={{ onBlur: getValidation }}
+ errorMessage={validationErrors?.access_token}
+ visibilityToggle={!isEditMode} // toggle offered only when not in edit mode
+ />
+);
export const displayField = ({
changeMethods,
getValidation,
@@ -150,7 +194,7 @@ export const displayField = ({
label={t('Display Name')}
onChange={changeMethods.onChange}
helpText={t(
- 'Pick a nickname for this database to display as in Superset.',
+ 'Pick a nickname for how the database will display in Superset.',
)}
/>
</>
diff --git a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/DatabaseConnectionForm/index.tsx b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/DatabaseConnectionForm/index.tsx
index e9b562a919..1fa0cbbecb 100644
--- a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/DatabaseConnectionForm/index.tsx
+++ b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/DatabaseConnectionForm/index.tsx
@@ -16,19 +16,21 @@
* specific language governing permissions and limitations
* under the License.
*/
-import React, { FormEvent } from 'react';
+import React, { FormEvent, useEffect } from 'react';
import { SupersetTheme, JsonObject } from '@superset-ui/core';
import { InputProps } from 'antd/lib/input';
import { Form } from 'src/components/Form';
import {
- hostField,
- portField,
+ accessTokenField,
databaseField,
- usernameField,
- passwordField,
displayField,
- queryField,
forceSSLField,
+ hostField,
+ httpPath,
+ passwordField,
+ portField,
+ queryField,
+ usernameField,
} from './CommonParameters';
import { validatedInputField } from './ValidatedInputField';
import { EncryptedField } from './EncryptedField';
@@ -42,6 +44,8 @@ export const FormFieldOrder = [
'database',
'username',
'password',
+ 'access_token',
+ 'http_path',
'database_name',
'credentials_info',
'service_account_info',
@@ -67,6 +71,8 @@ export interface FieldPropTypes {
} & { onParametersUploadFileChange: (value: any) => string } & {
onAddTableCatalog: () => void;
onRemoveTableCatalog: (idx: number) => void;
+ } & {
+ onExtraInputChange: (value: any) => void;
};
validationErrors: JsonObject | null;
getValidation: () => void;
@@ -80,10 +86,12 @@ export interface FieldPropTypes {
const FORM_FIELD_MAP = {
host: hostField,
+ http_path: httpPath,
port: portField,
database: databaseField,
username: usernameField,
password: passwordField,
+ access_token: accessTokenField,
database_name: displayField,
query: queryField,
encryption: forceSSLField,
@@ -96,20 +104,22 @@ const FORM_FIELD_MAP = {
};
const DatabaseConnectionForm = ({
- dbModel: { parameters },
- onParametersChange,
+ dbModel: { parameters, default_driver },
+ db,
+ editNewDb,
+ getPlaceholder,
+ getValidation,
+ isEditMode = false,
+ onAddTableCatalog,
onChange,
- onQueryChange,
+ onExtraInputChange,
+ onParametersChange,
onParametersUploadFileChange,
- onAddTableCatalog,
+ onQueryChange,
onRemoveTableCatalog,
- validationErrors,
- getValidation,
- db,
- isEditMode = false,
+ setDatabaseDriver,
sslForced,
- editNewDb,
- getPlaceholder,
+ validationErrors,
}: {
isEditMode?: boolean;
sslForced: boolean;
@@ -128,50 +138,60 @@ const DatabaseConnectionForm = ({
onParametersUploadFileChange?: (
event: FormEvent<InputProps> | { target: HTMLInputElement },
) => void;
+ onExtraInputChange: (
+ event: FormEvent<InputProps> | { target: HTMLInputElement },
+ ) => void;
onAddTableCatalog: () => void;
onRemoveTableCatalog: (idx: number) => void;
validationErrors: JsonObject | null;
getValidation: () => void;
getPlaceholder?: (field: string) => string | undefined;
-}) => (
- <Form>
- <div
- // @ts-ignore
- css={(theme: SupersetTheme) => [
- formScrollableStyles,
- validatedFormStyles(theme),
- ]}
- >
- {parameters &&
- FormFieldOrder.filter(
- (key: string) =>
- Object.keys(parameters.properties).includes(key) ||
- key === 'database_name',
- ).map(field =>
- FORM_FIELD_MAP[field]({
- required: parameters.required?.includes(field),
- changeMethods: {
- onParametersChange,
- onChange,
- onQueryChange,
- onParametersUploadFileChange,
- onAddTableCatalog,
- onRemoveTableCatalog,
- },
- validationErrors,
- getValidation,
- db,
- key: field,
- field,
- isEditMode,
- sslForced,
- editNewDb,
- placeholder: getPlaceholder ? getPlaceholder(field) : undefined,
- }),
- )}
- </div>
- </Form>
-);
+ setDatabaseDriver: (driver: string) => void;
+}) => {
+ useEffect(() => {
+ setDatabaseDriver(default_driver);
+ }, [default_driver]);
+ return (
+ <Form>
+ <div
+ // @ts-ignore
+ css={(theme: SupersetTheme) => [
+ formScrollableStyles,
+ validatedFormStyles(theme),
+ ]}
+ >
+ {parameters &&
+ FormFieldOrder.filter(
+ (key: string) =>
+ Object.keys(parameters.properties).includes(key) ||
+ key === 'database_name',
+ ).map(field =>
+ FORM_FIELD_MAP[field]({
+ required: parameters.required?.includes(field),
+ changeMethods: {
+ onParametersChange,
+ onChange,
+ onQueryChange,
+ onParametersUploadFileChange,
+ onAddTableCatalog,
+ onRemoveTableCatalog,
+ onExtraInputChange,
+ },
+ validationErrors,
+ getValidation,
+ db,
+ key: field,
+ field,
+ isEditMode,
+ sslForced,
+ editNewDb,
+ placeholder: getPlaceholder ? getPlaceholder(field) : undefined,
+ }),
+ )}
+ </div>
+ </Form>
+ );
+};
export const FormFieldMap = FORM_FIELD_MAP;
export default DatabaseConnectionForm;
diff --git a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/SqlAlchemyForm.tsx b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/SqlAlchemyForm.tsx
index 454070b52a..1fac50d343 100644
--- a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/SqlAlchemyForm.tsx
+++ b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/SqlAlchemyForm.tsx
@@ -101,7 +101,7 @@ const SqlAlchemyTab = ({
</StyledInputContainer>
<Button
onClick={testConnection}
- disabled={testInProgress}
+ loading={testInProgress}
cta
buttonStyle="link"
css={(theme: SupersetTheme) => wideButton(theme)}
diff --git a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/index.test.tsx b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/index.test.tsx
index 6a0173fd56..3bee53480f 100644
--- a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/index.test.tsx
+++ b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/index.test.tsx
@@ -226,6 +226,37 @@ fetchMock.mock(AVAILABLE_DB_ENDPOINT, {
supports_file_upload: false,
},
},
+ {
+ available_drivers: ['connector'],
+ default_driver: 'connector',
+ engine: 'databricks',
+ name: 'Databricks',
+ parameters: {
+ properties: {
+ access_token: {
+ type: 'string',
+ },
+ database: {
+ type: 'string',
+ },
+ host: {
+ type: 'string',
+ },
+ http_path: {
+ type: 'string',
+ },
+ port: {
+ format: 'int32',
+ type: 'integer',
+ },
+ },
+ required: ['access_token', 'database', 'host', 'http_path', 'port'],
+ type: 'object',
+ },
+ preferred: true,
+ sqlalchemy_uri_placeholder:
+ 'databricks+connector://token:{access_token}@{host}:{port}/{database_name}',
+ },
],
});
fetchMock.post(VALIDATE_PARAMS_ENDPOINT, {
@@ -238,6 +269,7 @@ const databaseFixture: DatabaseObject = {
database_name: 'Postgres',
name: 'PostgresDB',
is_managed_externally: false,
+ driver: 'psycopg2',
};
describe('DatabaseModal', () => {
@@ -355,8 +387,9 @@ describe('DatabaseModal', () => {
});
// there should be a footer but it should not have any buttons in it
expect(footer[0]).toBeEmptyDOMElement();
+
// This is how many preferred databases are rendered
- expect(preferredDbIcon).toHaveLength(4);
+ expect(preferredDbIcon).toHaveLength(5);
});
test('renders the "Basic" tab of SQL Alchemy form (step 2 of 2) correctly', async () => {
diff --git a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/index.tsx b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/index.tsx
index 003f9d64be..faa3f48cad 100644
--- a/superset-frontend/src/views/CRUD/data/database/DatabaseModal/index.tsx
+++ b/superset-frontend/src/views/CRUD/data/database/DatabaseModal/index.tsx
@@ -132,19 +132,20 @@ interface DatabaseModalProps {
}
export enum ActionType {
+ addTableCatalogSheet, // NOTE(review): numeric enum — this alphabetical reorder renumbers members; safe only if values are never persisted or compared as numbers
configMethodChange,
dbSelected,
+ driverChange, // payload: driver name string (dispatched via setDatabaseDriver)
editorChange,
+ extraEditorChange,
+ extraInputChange, // payload: { name, value } from the onExtraInputChange handlers
fetched,
inputChange,
parametersChange,
+ queryChange,
+ removeTableCatalogSheet,
reset,
textChange,
- extraInputChange,
- extraEditorChange,
- addTableCatalogSheet,
- removeTableCatalogSheet,
- queryChange,
}
interface DBReducerPayloadType {
@@ -197,6 +198,10 @@ export type DBReducerActionType =
engine?: string;
configuration_method: CONFIGURATION_METHOD;
};
+ }
+ | {
+ type: ActionType.driverChange;
+ payload: string;
};
const StyledBtns = styled.div`
@@ -254,6 +259,19 @@ export function dbReducer(
}),
};
}
+ if (action.payload.name === 'http_path') {
+ return {
+ ...trimmedState,
+ extra: JSON.stringify({
+ ...extraJson,
+ engine_params: {
+ connect_args: {
+ [action.payload.name]: action.payload.value?.trim(),
+ },
+ },
+ }),
+ };
+ }
return {
...trimmedState,
extra: JSON.stringify({
@@ -408,6 +426,12 @@ export function dbReducer(
...action.payload,
};
+ case ActionType.driverChange:
+ return {
+ ...trimmedState,
+ driver: action.payload,
+ };
+
case ActionType.reset:
default:
return null;
@@ -578,10 +602,17 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
},
});
}
+
+ // make sure that button spinner animates
+ setLoading(true);
const errors = await getValidation(dbToUpdate, true);
if ((validationErrors && !isEmpty(validationErrors)) || errors) {
+ setLoading(false);
return;
}
+ setLoading(false);
+ // end spinner animation
+
const parameters_schema = isEditMode
? dbToUpdate.parameters_schema?.properties
: dbModel?.parameters.properties;
@@ -873,6 +904,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
key="submit"
buttonStyle="primary"
onClick={onSave}
+ loading={isLoading}
>
{t('Connect')}
</StyledFooterButton>
@@ -890,6 +922,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
buttonStyle="primary"
onClick={onSave}
data-test="modal-confirm-button"
+ loading={isLoading}
>
{t('Finish')}
</StyledFooterButton>
@@ -909,6 +942,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
buttonStyle="primary"
onClick={onSave}
disabled={handleDisableOnImport()}
+ loading={isLoading}
>
{t('Connect')}
</StyledFooterButton>
@@ -929,6 +963,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
buttonStyle="primary"
onClick={onSave}
disabled={db?.is_managed_externally}
+ loading={isLoading}
tooltip={
db?.is_managed_externally
? t(
@@ -966,8 +1001,8 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
useEffect(() => {
if (show) {
setTabKey(DEFAULT_TAB_KEY);
- getAvailableDbs();
setLoading(true);
+ getAvailableDbs();
}
if (databaseId && show) {
fetchDB();
@@ -1254,6 +1289,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
sslForced={sslForced}
dbModel={dbModel}
db={db as DatabaseObject}
+ setDatabaseDriver={(driver: string) => {
+ onChange(ActionType.driverChange, driver);
+ }}
onParametersChange={({ target }: { target: HTMLInputElement }) =>
onChange(ActionType.parametersChange, {
type: target.type,
@@ -1262,6 +1300,12 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
value: target.value,
})
}
+ onExtraInputChange={({ target }: { target: HTMLInputElement }) =>
+ onChange(ActionType.extraInputChange, {
+ name: target.name,
+ value: target.value,
+ })
+ }
onChange={({ target }: { target: HTMLInputElement }) =>
onChange(ActionType.textChange, {
name: target.name,
@@ -1419,6 +1463,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
sslForced={sslForced}
dbModel={dbModel}
db={db as DatabaseObject}
+ setDatabaseDriver={(driver: string) => {
+ onChange(ActionType.driverChange, driver);
+ }}
onParametersChange={({ target }: { target: HTMLInputElement }) =>
onChange(ActionType.parametersChange, {
type: target.type,
@@ -1427,6 +1474,12 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
value: target.value,
})
}
+ onExtraInputChange={({ target }: { target: HTMLInputElement }) =>
+ onChange(ActionType.extraInputChange, {
+ name: target.name,
+ value: target.value,
+ })
+ }
onChange={({ target }: { target: HTMLInputElement }) =>
onChange(ActionType.textChange, {
name: target.name,
@@ -1606,6 +1659,9 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
db={db}
sslForced={sslForced}
dbModel={dbModel}
+ setDatabaseDriver={(driver: string) => {
+ onChange(ActionType.driverChange, driver);
+ }}
onAddTableCatalog={() => {
setDB({ type: ActionType.addTableCatalogSheet });
}}
@@ -1615,6 +1671,16 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
value: target.value,
})
}
+ onExtraInputChange={({
+ target,
+ }: {
+ target: HTMLInputElement;
+ }) =>
+ onChange(ActionType.extraInputChange, {
+ name: target.name,
+ value: target.value,
+ })
+ }
onRemoveTableCatalog={(idx: number) => {
setDB({
type: ActionType.removeTableCatalogSheet,
diff --git a/superset-frontend/src/views/CRUD/data/database/types.ts b/superset-frontend/src/views/CRUD/data/database/types.ts
index 373e0dbf83..18b7eabff2 100644
--- a/superset-frontend/src/views/CRUD/data/database/types.ts
+++ b/superset-frontend/src/views/CRUD/data/database/types.ts
@@ -34,6 +34,7 @@ export type DatabaseObject = {
configuration_method: CONFIGURATION_METHOD;
created_by?: null | DatabaseUser;
database_name: string;
+ driver: string;
engine?: string;
extra?: string;
id?: number;
@@ -41,6 +42,7 @@ export type DatabaseObject = {
paramProperties?: Record<string, any>;
sqlalchemy_uri?: string;
parameters?: {
+ access_token?: string;
database_name?: string;
host?: string;
port?: number;
@@ -90,6 +92,7 @@ export type DatabaseObject = {
};
export type DatabaseForm = {
+ default_driver: string;
engine: string;
name: string;
parameters: {
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 87951d396e..3e2c0f56ba 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1756,7 +1756,7 @@ class BasicParametersMixin:
def build_sqlalchemy_uri( # pylint: disable=unused-argument
cls,
parameters: BasicParametersType,
- encryted_extra: Optional[Dict[str, str]] = None,
+ encrypted_extra: Optional[Dict[str, str]] = None,
) -> str:
# make a copy so that we don't update the original
query = parameters.get("query", {}).copy()
diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py
index 8dce8a5940..7ebe6ab1ab 100644
--- a/superset/db_engine_specs/databricks.py
+++ b/superset/db_engine_specs/databricks.py
@@ -15,18 +15,83 @@
# specific language governing permissions and limitations
# under the License.
+import json
from datetime import datetime
-from typing import Any, Dict, Optional, Set, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
+from apispec import APISpec
+from apispec.ext.marshmallow import MarshmallowPlugin
+from flask_babel import gettext as __
+from marshmallow import fields, Schema
+from marshmallow.validate import Range
from sqlalchemy.engine.reflection import Inspector
+from sqlalchemy.engine.url import URL
+from typing_extensions import TypedDict
from superset.constants import USER_AGENT
-from superset.db_engine_specs.base import BaseEngineSpec
+from superset.databases.utils import make_url_safe
+from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
from superset.db_engine_specs.hive import HiveEngineSpec
+from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.utils.network import is_hostname_valid, is_port_open
if TYPE_CHECKING:
from superset.models.core import Database
+
+class DatabricksParametersSchema(Schema):
+ """
+ Fields the dynamic form sends so build_sqlalchemy_uri can assemble
+ databricks+connector://token:<access_token>@<host>:<port>/<database>
+ """
+
+ access_token = fields.Str(required=True) # becomes the URL password (username is the literal "token")
+ host = fields.Str(required=True)
+ port = fields.Integer(
+ required=True,
+ description=__("Database port"),
+ validate=Range(min=0, max=2**16, max_inclusive=False), # accepts 0..65535
+ )
+ database = fields.Str(required=True)
+ encryption = fields.Boolean( # truthy -> encryption_parameters are added to the URL query
+ required=False, description=__("Use an encrypted connection to the database")
+ )
+
+
+class DatabricksPropertiesSchema(DatabricksParametersSchema):
+ """
+ Everything the UI form exposes (via parameters_json_schema): the URI
+ fields above plus http_path, which is kept in extra rather than in the URI.
+ """
+
+ http_path = fields.Str(required=True) # read back from extra["engine_params"]["connect_args"]
+
+
+class DatabricksParametersType(TypedDict):
+ """
+ The URI parameters, i.e. the keys that do not exist on the Database
+ model. build_sqlalchemy_uri consumes this shape and
+ get_parameters_from_uri returns it.
+ """
+
+ access_token: str
+ host: str
+ port: int
+ database: str
+ encryption: bool # NOTE(review): optional (required=False) in DatabricksParametersSchema; total=False/NotRequired may fit better
+
+
+class DatabricksPropertiesType(TypedDict):
+ """
+ What validate_parameters receives from the dynamic form:
+ ``parameters`` (the URI fields) plus ``extra``, a JSON string whose
+ engine_params.connect_args may carry ``http_path``.
+ """
+
+ parameters: DatabricksParametersType
+ extra: str
+
+
time_grain_expressions = {
None: "{col}",
"PT1S": "date_trunc('second', {col})",
@@ -78,13 +143,21 @@ class DatabricksODBCEngineSpec(BaseEngineSpec):
return HiveEngineSpec.epoch_to_dttm()
-class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec):
- engine_name = "Databricks Native Connector"
+class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin):
+ engine_name = "Databricks"
engine = "databricks"
drivers = {"connector": "Native all-purpose driver"}
default_driver = "connector"
+ parameters_schema = DatabricksParametersSchema()
+ properties_schema = DatabricksPropertiesSchema()
+
+ sqlalchemy_uri_placeholder = (
+ "databricks+connector://token:{access_token}@{host}:{port}/{database_name}"
+ )
+ encryption_parameters = {"ssl": "1"}
+
@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
"""
@@ -107,3 +180,178 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec):
return super().get_table_names(
database, inspector, schema
) - cls.get_view_names(database, inspector, schema)
+
+ @classmethod
+ def build_sqlalchemy_uri( # type: ignore
+ cls, parameters: DatabricksParametersType, *_
+ ) -> str:
+ """Compose ``databricks+connector://token:<access_token>@<host>:<port>/<database>``."""
+ wants_encryption = bool(parameters.get("encryption"))
+ if wants_encryption and not cls.encryption_parameters:
+ raise Exception("Unable to build a URL with encryption enabled")
+ query = dict(cls.encryption_parameters) if wants_encryption else {}
+ drivername = f"{cls.engine}+{cls.default_driver}".rstrip("+")
+
+ url = URL(
+ drivername,
+ username="token",
+ password=parameters.get("access_token"),
+ host=parameters["host"],
+ port=parameters["port"],
+ database=parameters["database"],
+ query=query,
+ )
+ # extra positional args (e.g. encrypted_extra) are accepted but unused
+ return str(url)
+
+ @classmethod
+ def extract_errors(
+ cls, ex: Exception, context: Optional[Dict[str, Any]] = None
+ ) -> List[SupersetError]:
+ """Map ``ex`` to SupersetErrors; ``context`` may be None or lack URL keys."""
+ raw_message = cls._extract_error_message(ex)
+ context = context or {}
+ context = { # .get(): a missing key must not raise KeyError and mask ``ex``
+ "host": context.get("hostname"),
+ "access_token": context.get("password"),
+ "port": context.get("port"),
+ "username": context.get("username"),
+ "database": context.get("database"),
+ }
+ for regex, (message, error_type, extra) in cls.custom_errors.items():
+ match = regex.search(raw_message)
+ if match:
+ params = {**context, **match.groupdict()}
+ extra = {**extra, "engine_name": cls.engine_name} # copy: don't mutate cls.custom_errors
+ return [
+ SupersetError(
+ error_type=error_type,
+ message=message % params,
+ level=ErrorLevel.ERROR,
+ extra=extra,
+ )
+ ]
+
+ return [
+ SupersetError(
+ error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
+ message=raw_message,
+ level=ErrorLevel.ERROR,
+ extra={"engine_name": cls.engine_name},
+ )
+ ]
+
+ @classmethod
+ def get_parameters_from_uri( # type: ignore
+ cls, uri: str, *_, **__
+ ) -> DatabricksParametersType:
+ """Split a databricks URI back into the form fields build_sqlalchemy_uri uses."""
+ parsed = make_url_safe(uri)
+ query_items = parsed.query.items()
+ uses_encryption = all(pair in query_items for pair in cls.encryption_parameters.items())
+ return DatabricksParametersType(
+ access_token=parsed.password,
+ host=parsed.host,
+ port=parsed.port,
+ database=parsed.database,
+ encryption=uses_encryption,
+ )
+
+ @classmethod
+ def validate_parameters( # type: ignore
+ cls,
+ properties: DatabricksPropertiesType,
+ ) -> List[SupersetError]:
+ """Report missing form fields (incl. http_path from extra), then check host resolves and port is open."""
+ errors: List[SupersetError] = []
+ required = {"access_token", "host", "port", "database", "http_path", "extra"} # http_path is required=True in DatabricksPropertiesSchema
+ extra = json.loads(properties.get("extra") or "{}") # None/"" extra -> empty dict instead of TypeError/JSONDecodeError
+ engine_params = extra.get("engine_params") or {}
+ connect_args = engine_params.get("connect_args") or {}
+ parameters = {
+ **properties,
+ **(properties.get("parameters") or {}),
+ }
+ if connect_args.get("http_path"): # the form stores http_path under extra, not in parameters
+ parameters["http_path"] = connect_args.get("http_path")
+ present = {key for key in parameters if parameters.get(key, ())}
+ missing = sorted(required - present)
+
+ if missing:
+ errors.append(
+ SupersetError(
+ message=f'One or more parameters are missing: {", ".join(missing)}',
+ error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR,
+ level=ErrorLevel.WARNING,
+ extra={"missing": missing},
+ ),
+ )
+
+ host = parameters.get("host", None)
+ if not host:
+ return errors
+
+ if not is_hostname_valid(host): # type: ignore
+ errors.append(
+ SupersetError(
+ message="The hostname provided can't be resolved.",
+ error_type=SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
+ level=ErrorLevel.ERROR,
+ extra={"invalid": ["host"]},
+ ),
+ )
+ return errors
+
+ port = parameters.get("port", None)
+ if not port:
+ return errors
+ try:
+ port = int(port) # type: ignore
+ except (ValueError, TypeError): # NOTE: falls through, so the range check below adds a second port error
+ errors.append(
+ SupersetError(
+ message="Port must be a valid integer.",
+ error_type=SupersetErrorType.CONNECTION_INVALID_PORT_ERROR,
+ level=ErrorLevel.ERROR,
+ extra={"invalid": ["port"]},
+ ),
+ )
+ if not (isinstance(port, int) and 0 <= port < 2**16):
+ errors.append(
+ SupersetError(
+ message=(
+ "The port must be an integer between 0 and 65535 "
+ "(inclusive)."
+ ),
+ error_type=SupersetErrorType.CONNECTION_INVALID_PORT_ERROR,
+ level=ErrorLevel.ERROR,
+ extra={"invalid": ["port"]},
+ ),
+ )
+ elif not is_port_open(host, port): # type: ignore
+ errors.append(
+ SupersetError(
+ message="The port is closed.",
+ error_type=SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR,
+ level=ErrorLevel.ERROR,
+ extra={"invalid": ["port"]},
+ ),
+ )
+ return errors
+
+ @classmethod
+ def parameters_json_schema(cls) -> Any:
+ """Render ``properties_schema`` (URI fields plus http_path) as an OpenAPI schema dict, or None when it is falsy."""
+ schema = cls.properties_schema
+ if not schema:
+ return None
+
+ spec = APISpec(
+ title="Database Parameters",
+ version="1.0.0",
+ openapi_version="3.0.2",
+ plugins=[MarshmallowPlugin()],
+ )
+ component_name = cls.__name__
+ spec.components.schema(component_name, schema=schema)
+ return spec.to_dict()["components"]["schemas"][component_name]
diff --git a/tests/integration_tests/db_engine_specs/databricks_tests.py b/tests/integration_tests/db_engine_specs/databricks_tests.py
new file mode 100644
index 0000000000..b399e41fd3
--- /dev/null
+++ b/tests/integration_tests/db_engine_specs/databricks_tests.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from textwrap import dedent
+from unittest import mock
+
+from sqlalchemy import column, literal_column
+
+from superset.constants import USER_AGENT
+from superset.db_engine_specs import get_engine_spec
+from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
+from tests.integration_tests.db_engine_specs.base_tests import (
+ assert_generic_types,
+ TestDbEngineSpec,
+)
+from tests.integration_tests.fixtures.certificates import ssl_certificate
+from tests.integration_tests.fixtures.database import default_db_extra
+
+
+class TestDatabricksDbEngineSpec(TestDbEngineSpec):
+ def test_get_engine_spec(self):
+ """
+ DB Eng Specs (databricks): every driver (connector, pyodbc, pyhive) resolves to engine "databricks"
+ """
+ assert get_engine_spec("databricks", "connector").engine == "databricks"
+ assert get_engine_spec("databricks", "pyodbc").engine == "databricks"
+ assert get_engine_spec("databricks", "pyhive").engine == "databricks"
+
+ def test_extras_without_ssl(self): # default extra and no server_cert: no connect_args should be added
+ db = mock.Mock()
+ db.extra = default_db_extra
+ db.server_cert = None
+ extras = DatabricksNativeEngineSpec.get_extra_params(db)
+ assert "connect_args" not in extras["engine_params"]
+
+ def test_extras_with_user_agent(self): # USER_AGENT must appear in http_headers and _user_agent_entry
+ db = mock.Mock() # NOTE(review): server_cert is not set, so reading it yields an auto-created (truthy) Mock
+ db.extra = default_db_extra
+ extras = DatabricksNativeEngineSpec.get_extra_params(db)
+ _, user_agent = extras["http_headers"][0]
+ user_agent_entry = extras["_user_agent_entry"]
+ assert user_agent == USER_AGENT
+ assert user_agent_entry == USER_AGENT
+
+ def test_extras_with_ssl_custom(self): # user-supplied connect_args ({"ssl": "1"}) must survive get_extra_params
+ db = mock.Mock()
+ db.extra = default_db_extra.replace(
+ '"engine_params": {}',
+ '"engine_params": {"connect_args": {"ssl": "1"}}',
+ )
+ db.server_cert = ssl_certificate
+ extras = DatabricksNativeEngineSpec.get_extra_params(db)
+ connect_args = extras["engine_params"]["connect_args"]
+ assert connect_args["ssl"] == "1"
diff --git a/tests/unit_tests/db_engine_specs/test_databricks.py b/tests/unit_tests/db_engine_specs/test_databricks.py
new file mode 100644
index 0000000000..0cc0907f4d
--- /dev/null
+++ b/tests/unit_tests/db_engine_specs/test_databricks.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument, import-outside-toplevel, protected-access
+
+import json
+
+from superset.utils.core import GenericDataType
+from tests.integration_tests.db_engine_specs.base_tests import assert_generic_types
+
+
+def test_get_parameters_from_uri() -> None:
+ """
+ ``get_parameters_from_uri`` splits a URI into form fields that round-trip through JSON.
+ """
+ from superset.db_engine_specs.databricks import (
+ DatabricksNativeEngineSpec,
+ DatabricksParametersType,
+ )
+
+ uri = "databricks+connector://token:abc12345@my_hostname:1234/test"
+ expected = DatabricksParametersType(
+ access_token="abc12345",
+ host="my_hostname",
+ port=1234,
+ database="test",
+ encryption=False,
+ )
+
+ parsed = DatabricksNativeEngineSpec.get_parameters_from_uri(uri)
+ assert parsed == expected
+ serialized = json.dumps(parsed)
+ assert json.loads(serialized) == parsed
+
+
+def test_build_sqlalchemy_uri() -> None:
+ """
+ Test that the form parameters compile into the expected
+ sqlalchemy_uri (encryption off, so no query string).
+ """
+ from superset.db_engine_specs.databricks import (
+ DatabricksNativeEngineSpec,
+ DatabricksParametersType,
+ )
+
+ expected_uri = "databricks+connector://token:abc12345@my_hostname:1234/test"
+ parameters = DatabricksParametersType(
+ access_token="abc12345",
+ host="my_hostname",
+ port=1234,
+ database="test",
+ encryption=False,
+ )
+
+ # the second positional argument stands in for encrypted_extra
+ uri = DatabricksNativeEngineSpec.build_sqlalchemy_uri(
+ parameters,
+ None,
+ )
+
+ assert uri == expected_uri
+
+
+def test_parameters_json_schema() -> None:
+ """
+ The properties schema (URI fields plus http_path) renders as this OpenAPI object.
+ """
+ from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
+
+ string_field = {"type": "string"}
+ expected = {
+ "type": "object",
+ "required": ["access_token", "database", "host", "http_path", "port"],
+ "properties": {
+ "access_token": string_field,
+ "database": string_field,
+ "host": string_field,
+ "http_path": string_field,
+ "encryption": {
+ "type": "boolean",
+ "description": "Use an encrypted connection to the database",
+ },
+ "port": {
+ "type": "integer",
+ "format": "int32",
+ "minimum": 0,
+ "maximum": 65536,
+ "description": "Database port",
+ },
+ },
+ }
+ assert DatabricksNativeEngineSpec.parameters_json_schema() == expected
+
+
+def test_generic_type() -> None:
+ """
+ Native column type names map to the expected GenericDataType buckets.
+ """
+ from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
+
+ numeric = ("SMALLINT", "INTEGER", "BIGINT", "DECIMAL", "NUMERIC", "REAL")
+ numeric += ("DOUBLE PRECISION", "MONEY")
+ string = ("CHAR", "VARCHAR", "TEXT")
+ temporal = ("DATE", "TIMESTAMP", "TIME")
+ boolean = ("BOOLEAN",)
+
+ groups = (
+ (numeric, GenericDataType.NUMERIC),
+ (string, GenericDataType.STRING),
+ (temporal, GenericDataType.TEMPORAL),
+ (boolean, GenericDataType.BOOLEAN),
+ )
+
+ # flatten to the ((type_str, expected), ...) shape assert_generic_types takes
+ type_expectations = tuple(
+ (type_str, generic_type)
+ for type_names, generic_type in groups
+ for type_str in type_names
+ )
+
+ # order is preserved: numeric, string, temporal, boolean
+ assert_generic_types(DatabricksNativeEngineSpec, type_expectations)