You are viewing a plain-text version of this content. (The link to the canonical version was lost in this plain-text rendering.)
Posted to commits@superset.apache.org by di...@apache.org on 2023/04/04 10:36:28 UTC

[superset] 01/03: WIP

This is an automated email from the ASF dual-hosted git repository.

diegopucci pushed a commit to branch feat/sqllab-natural-language
in repository https://gitbox.apache.org/repos/asf/superset.git

commit bf4db70b83a9a2a622ab5083ebcf78375793aa3c
Author: geido <di...@gmail.com>
AuthorDate: Thu Feb 23 18:53:32 2023 +0100

    WIP
---
 .../SqlLab/components/AceEditorWrapper/index.tsx   |   4 +-
 .../src/SqlLab/components/SqlEditor/index.jsx      | 141 +++++++++++++++++++--
 superset-frontend/src/components/index.ts          |   1 +
 superset/sqllab/api.py                             |  61 +++++++++
 superset/sqllab/schemas.py                         |   3 +-
 5 files changed, 195 insertions(+), 15 deletions(-)

diff --git a/superset-frontend/src/SqlLab/components/AceEditorWrapper/index.tsx b/superset-frontend/src/SqlLab/components/AceEditorWrapper/index.tsx
index 0dd3385ea5..412fe48333 100644
--- a/superset-frontend/src/SqlLab/components/AceEditorWrapper/index.tsx
+++ b/superset-frontend/src/SqlLab/components/AceEditorWrapper/index.tsx
@@ -57,6 +57,7 @@ type AceEditorWrapperProps = {
   extendedTables?: Array<{ name: string; columns: any[] }>;
   height: string;
   hotkeys: HotKey[];
+  initialSql?: string; // when non-empty, shown instead of queryEditor.sql
 };
 
 const StyledAceEditor = styled(AceEditor)`
@@ -90,6 +91,7 @@ const AceEditorWrapper = ({
   extendedTables = [],
   height,
   hotkeys,
+  initialSql, // optional override for queryEditor.sql (see currentSql)
 }: AceEditorWrapperProps) => {
   const dispatch = useDispatch();
 
@@ -103,7 +105,7 @@ const AceEditorWrapper = ({
     'validationResult',
     'schema',
   ]);
-  const currentSql = queryEditor.sql ?? '';
+  const currentSql = initialSql || queryEditor.sql || ''; // '' initialSql falls back
   const functionNames = queryEditor.functionNames ?? [];
   const schemas = queryEditor.schemaOptions ?? [];
   const tables = queryEditor.tableOptions ?? [];
diff --git a/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx b/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx
index d7626c8cbf..d517b4e939 100644
--- a/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx
+++ b/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx
@@ -31,8 +31,8 @@ import Mousetrap from 'mousetrap';
 import Button from 'src/components/Button';
 import Timer from 'src/components/Timer';
 import ResizableSidebar from 'src/components/ResizableSidebar';
-import { AntdDropdown, AntdSwitch } from 'src/components';
-import { Input } from 'src/components/Input';
+import { AntdDropdown, AntdSwitch, Result } from 'src/components';
+import { Input, TextArea } from 'src/components/Input';
 import { Menu } from 'src/components/Menu';
 import Icons from 'src/components/Icons';
 import { detectOS } from 'src/utils/common';
@@ -86,6 +86,7 @@ import SqlEditorLeftBar from '../SqlEditorLeftBar';
 import AceEditorWrapper from '../AceEditorWrapper';
 import RunQueryActionButton from '../RunQueryActionButton';
 import QueryLimitSelect from '../QueryLimitSelect';
+import { SupersetClient } from '@superset-ui/core';
 
 const bootstrapData = getBootstrapData();
 const validatorMap =
@@ -222,10 +223,13 @@ const SqlEditor = ({
         database: databases[dbId],
         latestQuery: queries[latestQueryId],
         hideLeftBar,
+        tables,
       };
     },
   );
 
+  // `tables` is also read by the natural-language (NLP) query generator below.
+
   const [height, setHeight] = useState(0);
   const [autorun, setAutorun] = useState(queryEditor.autorun);
   const [ctas, setCtas] = useState('');
@@ -640,6 +644,112 @@ const SqlEditor = ({
     );
   };
 
+  const [editorType, setEditorType] = useState('sql'); // NLP tab: 'sql' | 'nlp'
+  const [NLPQuery, setNLPQuery] = useState('');
+  const [NLPResult, setNLPResult] = useState(''); // last generated SQL
+  const [NLPLoading, setNLPLoading] = useState(false);
+  const [NLPError, setNLPError] = useState('');
+  const hasSelectedTables = tables.length > 0;
+  // Prompt: one "# Table <name>, columns = [a, b]" line per table with columns.
+  const buildNLPPrompt = request => {
+    const tableLines = tables
+      .filter(table => table?.columns?.length)
+      .map(({ name, columns }) => {
+        const columnNames = columns.map(col => col.name).join(', ');
+        return `# Table ${name}, columns = [${columnNames}]\n\n`;
+      });
+    return [
+      '# Given the following table/s definition\n\n',
+      ...tableLines,
+      '# Create one valid SQL SELECT statement with the following constraints\n',
+      '# For example: SELECT column FROM table;\n',
+      '# Do NOT generate more than one SELECT statement\n',
+      '# Do NOT generate any text other than one valid SELECT statement\n',
+      '# Do ONLY use SELECT\n',
+      `# Respond with a SQL statement to select ${request} from the given tables ->`,
+    ].join('');
+  };
+  // Shows the SQL editor with the result on success; keeps an error otherwise.
+  const handleNLPGeneration = () => {
+    const request = NLPQuery.trim();
+    if (NLPLoading || !request) return; // ignore duplicate or empty requests
+    setNLPLoading(true);
+    setNLPError('');
+    SupersetClient.post({
+      endpoint: '/api/v1/sqllab/nlp/',
+      jsonPayload: { prompt: buildNLPPrompt(request) },
+    })
+      .then(({ json }) => {
+        setNLPResult((json?.result ?? '').trim());
+        setEditorType('sql');
+      })
+      .catch(() => setNLPError('Could not generate a query. Please try again.'))
+      .finally(() => setNLPLoading(false));
+  };
+
+  const renderNLPMenu = (
+    <Menu
+      mode="horizontal"
+      selectedKeys={[editorType]} // controlled, so it follows editorType
+      onClick={({ key }) => setEditorType(key)}
+      css={css`
+        margin-bottom: 20px;
+      `}
+    >
+      <Menu.Item
+        disabled={NLPLoading}
+        key="sql"
+        icon={<Icons.ConsoleSqlOutlined />}
+      >
+        SQL
+      </Menu.Item>
+      <Menu.Item
+        disabled={NLPLoading}
+        key="nlp"
+        icon={<Icons.ExperimentOutlined />}
+      >
+        Natural language
+      </Menu.Item>
+    </Menu>
+  );
+
+  const renderNLPBottomBar = hasSelectedTables ? (
+    <Button
+      type="primary"
+      size="large"
+      onClick={handleNLPGeneration}
+      disabled={NLPLoading || !NLPQuery.trim()}
+      css={css`
+        margin-top: 20px;
+        margin-bottom: 20px;
+        padding: 20px;
+        width: 250px;
+      `}
+    >
+      {NLPLoading ? <Icons.LoadingOutlined /> : <Icons.ExperimentOutlined />}{' '}
+      Generate query
+    </Button>
+  ) : null;
+
+  const renderNLPForm = hasSelectedTables ? (
+    <>
+      <TextArea
+        disabled={NLPLoading}
+        rows={6}
+        value={NLPQuery} // controlled: keeps the text when switching tabs
+        onChange={e => setNLPQuery(e.target.value)}
+        placeholder="Get all fruits from the tree..."
+      />
+      {NLPError && <Result status="error" title={<small>{NLPError}</small>} />}
+    </>
+  ) : (
+    <Result
+      status="warning"
+      title={<small>Please select one or more tables</small>}
+      subTitle="Select one or more tables on the left pane for the AI algorithm to gain context about your query"
+    />
+  );
+
   const queryPane = () => {
     const hotkeys = getHotkeyConfig();
     const { aceEditorHeight, southPaneHeight } =
@@ -657,17 +767,22 @@ const SqlEditor = ({
         onDragEnd={onResizeEnd}
       >
         <div ref={northPaneRef} className="north-pane">
-          <AceEditorWrapper
-            autocomplete={autocompleteEnabled}
-            onBlur={setQueryEditorAndSaveSql}
-            onChange={onSqlChanged}
-            queryEditorId={queryEditor.id}
-            database={database}
-            extendedTables={tables}
-            height={`${aceEditorHeight}px`}
-            hotkeys={hotkeys}
-          />
-          {renderEditorBottomBar(hotkeys)}
+          {renderNLPMenu /* SQL / Natural language tab switcher */}
+          {editorType === 'sql' && ( // SQL tab: the regular Ace editor
+            <AceEditorWrapper
+              autocomplete={autocompleteEnabled}
+              onBlur={setQueryEditorAndSaveSql}
+              onChange={onSqlChanged}
+              queryEditorId={queryEditor.id}
+              database={database}
+              extendedTables={tables}
+              height={`${aceEditorHeight}px`}
+              hotkeys={hotkeys}
+              initialSql={NLPResult} // NOTE(review): wins over queryEditor.sql while non-empty
+            />
+          )}
+          {editorType === 'nlp' && renderNLPForm /* NLP tab: request form */}
+          {editorType === 'sql' ? renderEditorBottomBar(hotkeys) : renderNLPBottomBar}
         </div>
         <SouthPane
           queryEditorId={queryEditor.id}
diff --git a/superset-frontend/src/components/index.ts b/superset-frontend/src/components/index.ts
index bfa341a9dd..84b10cc979 100644
--- a/superset-frontend/src/components/index.ts
+++ b/superset-frontend/src/components/index.ts
@@ -46,6 +46,7 @@ export {
   Tree,
   Typography,
   Upload,
+  Result,
 } from 'antd';
 
 /*
diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py
index 283c3ab638..dd659d753c 100644
--- a/superset/sqllab/api.py
+++ b/superset/sqllab/api.py
@@ -41,6 +41,7 @@ from superset.sqllab.execution_context_convertor import ExecutionContextConverto
 from superset.sqllab.query_render import SqlQueryRenderImpl
 from superset.sqllab.schemas import (
     ExecutePayloadSchema,
+    NLPPayloadSchema,
     QueryExecutionResponseSchema,
     sql_lab_get_results_schema,
 )
@@ -55,6 +56,7 @@ from superset.superset_typing import FlaskResponse
 from superset.utils import core as utils
 from superset.views.base import json_success
 from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics
+import openai
 
 config = app.config
 logger = logging.getLogger(__name__)
@@ -75,10 +77,69 @@ class SqlLabRestApi(BaseSupersetApi):
     }
     openapi_spec_tag = "SQL Lab"
     openapi_spec_component_schemas = (
+        NLPPayloadSchema,
         ExecutePayloadSchema,
         QueryExecutionResponseSchema,
     )
 
+    @expose("/nlp/", methods=["POST"])
+    @protect()
+    @statsd_metrics
+    @requires_json
+    def execute_completion(self) -> FlaskResponse:
+        """Completes a natural-language-to-SQL prompt
+        ---
+        post:
+          description: >-
+            Completes the prompt with OpenAI using the OPENAI_API_KEY config value
+          requestBody:
+            description: Prompt to complete
+            required: true
+            content:
+              application/json:
+                schema:
+                  $ref: '#/components/schemas/NLPPayloadSchema'
+          responses:
+            200:
+              description: Generated SQL text
+              content:
+                application/json:
+                  schema:
+                    type: object
+                    properties:
+                      result:
+                        type: string
+            400:
+              $ref: '#/components/responses/400'
+            401:
+              $ref: '#/components/responses/401'
+            403:
+              $ref: '#/components/responses/403'
+            500:
+              $ref: '#/components/responses/500'
+        """
+        from marshmallow import ValidationError
+        try:
+            payload = NLPPayloadSchema().load(request.json)
+        except ValidationError as error:
+            return self.response_400(message=error.messages)
+        api_key = config.get("OPENAI_API_KEY")  # never hard-code secrets
+        if not api_key:
+            return self.response_500(message="OPENAI_API_KEY is not configured")
+        try:
+            completion = openai.Completion.create(
+                api_key=api_key,
+                engine="code-davinci-002",
+                prompt=payload["prompt"],
+                max_tokens=100,
+                temperature=0,
+            )
+            result = completion.choices[0]["text"]
+        except Exception:  # pylint: disable=broad-except
+            logger.exception("NLP completion request failed")
+            return self.response_500(message="Failed to generate SQL")
+        return self.response(200, result=result)
+
     @expose("/results/")
     @protect()
     @statsd_metrics
diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py
index f238fda5c9..0ca760b757 100644
--- a/superset/sqllab/schemas.py
+++ b/superset/sqllab/schemas.py
@@ -23,7 +23,12 @@ sql_lab_get_results_schema = {
     },
     "required": ["key"],
 }
 
 
+class NLPPayloadSchema(Schema):
+    """Request body of the SQL Lab natural-language (NLP) completion endpoint."""
+    prompt = fields.String(required=True)
+
+
 class ExecutePayloadSchema(Schema):
     database_id = fields.Integer(required=True)