You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by di...@apache.org on 2023/04/04 10:36:28 UTC
[superset] 01/03: WIP
This is an automated email from the ASF dual-hosted git repository.
diegopucci pushed a commit to branch feat/sqllab-natural-language
in repository https://gitbox.apache.org/repos/asf/superset.git
commit bf4db70b83a9a2a622ab5083ebcf78375793aa3c
Author: geido <di...@gmail.com>
AuthorDate: Thu Feb 23 18:53:32 2023 +0100
WIP
---
.../SqlLab/components/AceEditorWrapper/index.tsx | 4 +-
.../src/SqlLab/components/SqlEditor/index.jsx | 141 +++++++++++++++++++--
superset-frontend/src/components/index.ts | 1 +
superset/sqllab/api.py | 61 +++++++++
superset/sqllab/schemas.py | 3 +-
5 files changed, 195 insertions(+), 15 deletions(-)
diff --git a/superset-frontend/src/SqlLab/components/AceEditorWrapper/index.tsx b/superset-frontend/src/SqlLab/components/AceEditorWrapper/index.tsx
index 0dd3385ea5..412fe48333 100644
--- a/superset-frontend/src/SqlLab/components/AceEditorWrapper/index.tsx
+++ b/superset-frontend/src/SqlLab/components/AceEditorWrapper/index.tsx
@@ -57,6 +57,7 @@ type AceEditorWrapperProps = {
extendedTables?: Array<{ name: string; columns: any[] }>;
height: string;
hotkeys: HotKey[];
+ initialSql?: string;
};
const StyledAceEditor = styled(AceEditor)`
@@ -90,6 +91,7 @@ const AceEditorWrapper = ({
extendedTables = [],
height,
hotkeys,
+ initialSql,
}: AceEditorWrapperProps) => {
const dispatch = useDispatch();
@@ -103,7 +105,7 @@ const AceEditorWrapper = ({
'validationResult',
'schema',
]);
- const currentSql = queryEditor.sql ?? '';
+ const currentSql = initialSql || queryEditor.sql || '';
const functionNames = queryEditor.functionNames ?? [];
const schemas = queryEditor.schemaOptions ?? [];
const tables = queryEditor.tableOptions ?? [];
diff --git a/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx b/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx
index d7626c8cbf..d517b4e939 100644
--- a/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx
+++ b/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx
@@ -31,8 +31,8 @@ import Mousetrap from 'mousetrap';
import Button from 'src/components/Button';
import Timer from 'src/components/Timer';
import ResizableSidebar from 'src/components/ResizableSidebar';
-import { AntdDropdown, AntdSwitch } from 'src/components';
-import { Input } from 'src/components/Input';
+import { AntdDropdown, AntdSwitch, Result } from 'src/components';
+import { Input, TextArea } from 'src/components/Input';
import { Menu } from 'src/components/Menu';
import Icons from 'src/components/Icons';
import { detectOS } from 'src/utils/common';
@@ -86,6 +86,7 @@ import SqlEditorLeftBar from '../SqlEditorLeftBar';
import AceEditorWrapper from '../AceEditorWrapper';
import RunQueryActionButton from '../RunQueryActionButton';
import QueryLimitSelect from '../QueryLimitSelect';
+import { SupersetClient } from '@superset-ui/core';
const bootstrapData = getBootstrapData();
const validatorMap =
@@ -222,10 +223,13 @@ const SqlEditor = ({
database: databases[dbId],
latestQuery: queries[latestQueryId],
hideLeftBar,
+ tables,
};
},
);
+ console.log('tables', tables);
+
const [height, setHeight] = useState(0);
const [autorun, setAutorun] = useState(queryEditor.autorun);
const [ctas, setCtas] = useState('');
@@ -640,6 +644,112 @@ const SqlEditor = ({
);
};
+ // NLP implementation
+ const [editorType, setEditorType] = useState('sql');
+ const [NLPQuery, setNLPQuery] = useState('');
+ const [NLPResult, setNLPResult] = useState('');
+ const [NLPLoading, setNLPLoading] = useState(false);
+ const handleNLPGeneration = async () => {
+ setNLPLoading(true);
+ let tablesContext = "# Given the following table/s definition\n\n";
+ for(let t = 0; t < tables.length; t += 1) {
+ const table = tables[t];
+ if (table?.columns?.length) {
+ tablesContext += `# Table ${table.name}, columns = [`;
+ for(let c = 0; c < table.columns.length; c += 1) {
+ const col = table.columns[c];
+ tablesContext += col.name;
+ tablesContext += table.columns.at(-1)?.name === col.name ? "]" : ", "
+ }
+ tablesContext += `\n\n`;
+ }
+ }
+ tablesContext += `# Create one valid SQL SELECT statement with the following constraints\n`;
+ tablesContext += `# For example: SELECT column FROM table;\n`;
+ tablesContext += `# Do NOT generate more tha one SELECT statement\n`;
+ tablesContext += `# Do NOT generate any text other than one valid SELECT statement\n`;
+ tablesContext += `# Do ONLY use SELECT\n`;
+ tablesContext += `# Respond with a SQL statement to select ${NLPQuery} from the given tables ->`;
+ const postPayload = {
+ prompt: tablesContext,
+ }
+ SupersetClient.post({
+ endpoint: "api/v1/sqllab/nlp",
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify(postPayload),
+ parseMethod: 'json-bigint',
+ })
+ .then(({ json }) => {
+ setNLPResult(json.result);
+ setEditorType('sql')
+ setNLPLoading(false);
+ })
+ .catch(() => {
+ setNLPLoading(false);
+ });
+
+ console.log(tablesContext);
+ }
+ const renderNLPMenu = (
+ <Menu
+ mode="horizontal"
+ defaultSelectedKeys={[editorType]}
+ css={css`
+ margin-bottom: 20px;
+ `}
+ >
+ <Menu.Item
+ disabled={NLPLoading}
+ key="sql"
+ icon={<Icons.ConsoleSqlOutlined />}
+ onClick={({ key }) => setEditorType(key)}
+ >
+ SQL
+ </Menu.Item>
+ <Menu.Item
+ disabled={NLPLoading}
+ key="nlp"
+ icon={<Icons.ExperimentOutlined />}
+ onClick={({ key }) => setEditorType(key)}
+ >
+ Natural language
+ </Menu.Item>
+ </Menu>
+ );
+
+ const renderNLPBottomBar = (
+ tables.length > 0 ? <Button
+ type="primary"
+ size="large"
+ onClick={handleNLPGeneration}
+ css={css`
+ margin-top: 20px;
+ margin-bottom: 20px;
+ padding: 20px;
+ width: 250px;
+ `}
+ >
+ {!NLPLoading ? <Icons.ExperimentOutlined /> : <Icons.LoadingOutlined />}{' '}
+ Generate query
+ </Button> : null
+ );
+
+ const renderNLPForm =
+ tables.length > 0 ? (
+ <TextArea
+ disabled={NLPLoading}
+ rows={6}
+ onChange={e => setNLPQuery(e.target.value)}
+ placeholder="Get all fruits from the tree..."
+ />
+ ) : (
+ <Result
+ status="warning"
+ title={<small>Please select one or more tables</small>}
+ subTitle="Select one or more tables on the left pane for the AI algorithm to gain context about your query"
+ />
+ );
+
const queryPane = () => {
const hotkeys = getHotkeyConfig();
const { aceEditorHeight, southPaneHeight } =
@@ -657,17 +767,22 @@ const SqlEditor = ({
onDragEnd={onResizeEnd}
>
<div ref={northPaneRef} className="north-pane">
- <AceEditorWrapper
- autocomplete={autocompleteEnabled}
- onBlur={setQueryEditorAndSaveSql}
- onChange={onSqlChanged}
- queryEditorId={queryEditor.id}
- database={database}
- extendedTables={tables}
- height={`${aceEditorHeight}px`}
- hotkeys={hotkeys}
- />
- {renderEditorBottomBar(hotkeys)}
+ {renderNLPMenu}
+ {editorType === 'sql' && (
+ <AceEditorWrapper
+ autocomplete={autocompleteEnabled}
+ onBlur={setQueryEditorAndSaveSql}
+ onChange={onSqlChanged}
+ queryEditorId={queryEditor.id}
+ database={database}
+ extendedTables={tables}
+ height={`${aceEditorHeight}px`}
+ hotkeys={hotkeys}
+ initialSql={NLPResult}
+ />
+ )}
+ {editorType === 'nlp' && renderNLPForm}
+ {editorType === 'sql' ? renderEditorBottomBar(hotkeys) : renderNLPBottomBar}
</div>
<SouthPane
queryEditorId={queryEditor.id}
diff --git a/superset-frontend/src/components/index.ts b/superset-frontend/src/components/index.ts
index bfa341a9dd..84b10cc979 100644
--- a/superset-frontend/src/components/index.ts
+++ b/superset-frontend/src/components/index.ts
@@ -46,6 +46,7 @@ export {
Tree,
Typography,
Upload,
+ Result,
} from 'antd';
/*
diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py
index 283c3ab638..dd659d753c 100644
--- a/superset/sqllab/api.py
+++ b/superset/sqllab/api.py
@@ -41,6 +41,7 @@ from superset.sqllab.execution_context_convertor import ExecutionContextConverto
from superset.sqllab.query_render import SqlQueryRenderImpl
from superset.sqllab.schemas import (
ExecutePayloadSchema,
+ NLPPayloadSchema,
QueryExecutionResponseSchema,
sql_lab_get_results_schema,
)
@@ -55,6 +56,7 @@ from superset.superset_typing import FlaskResponse
from superset.utils import core as utils
from superset.views.base import json_success
from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics
+import openai
config = app.config
logger = logging.getLogger(__name__)
@@ -75,10 +77,69 @@ class SqlLabRestApi(BaseSupersetApi):
}
openapi_spec_tag = "SQL Lab"
openapi_spec_component_schemas = (
+ NLPPayloadSchema,
ExecutePayloadSchema,
QueryExecutionResponseSchema,
)
+ @expose("/nlp/", methods=["POST"])
+ # @protect()
+ # @statsd_metrics
+ # @requires_json
+ def execute_completion(self) -> FlaskResponse:
+ """Executes a SQL query
+ ---
+ post:
+ description: >-
+ Starts the execution of a SQL query
+ requestBody:
+ description: SQL query and params
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/NLPPayloadSchema'
+ responses:
+ 200:
+ description: Query execution result
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/NLPPayloadSchema'
+ 202:
+ description: Query execution result, query still running
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/NLPPayloadSchema'
+ 400:
+ $ref: '#/components/responses/400'
+ 401:
+ $ref: '#/components/responses/401'
+ 403:
+ $ref: '#/components/responses/403'
+ 404:
+ $ref: '#/components/responses/404'
+ 500:
+ $ref: '#/components/responses/500'
+ """
+ try:
+ openai.api_key = "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"
+ requestPrompt = request.json
+ completion = openai.Completion.create(
+ engine="code-davinci-002",
+ prompt=requestPrompt['prompt'],
+ max_tokens=100,
+ temperature=0
+ )
+ choice = {**completion.choices[0]}
+ payload = {
+ 'result': choice['text'],
+ }
+ return self.response(200, **payload)
+ except Exception as e:
+ print(e);
+
@expose("/results/")
@protect()
@statsd_metrics
diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py
index f238fda5c9..0ca760b757 100644
--- a/superset/sqllab/schemas.py
+++ b/superset/sqllab/schemas.py
@@ -23,7 +23,8 @@ sql_lab_get_results_schema = {
},
"required": ["key"],
}
-
+class NLPPayloadSchema(Schema):
+    """Payload schema referenced by the SQL Lab ``/nlp/`` endpoint's OpenAPI spec."""
+    # Required prompt text; the endpoint forwards it to the completion API.
+    prompt = fields.String(required=True)
class ExecutePayloadSchema(Schema):
database_id = fields.Integer(required=True)