You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@superset.apache.org by di...@apache.org on 2023/04/04 10:36:30 UTC

[superset] 03/03: Use cosine similarity

This is an automated email from the ASF dual-hosted git repository.

diegopucci pushed a commit to branch feat/sqllab-natural-language
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 184c6277f62f6c0b81c34c79615fa02ad9d4e3e2
Author: geido <di...@gmail.com>
AuthorDate: Tue Apr 4 12:35:59 2023 +0200

    Use cosine similarity
---
 .../src/SqlLab/components/SqlEditor/index.jsx      | 51 ++++--------
 superset/config.py                                 |  6 ++
 superset/sqllab/api.py                             | 94 ++++++++++++++++------
 superset/sqllab/schemas.py                         |  6 +-
 superset/views/datasource/views.py                 | 71 +++++++++++++++-
 5 files changed, 163 insertions(+), 65 deletions(-)

diff --git a/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx b/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx
index b1f8c0532c..823d9c1184 100644
--- a/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx
+++ b/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx
@@ -228,8 +228,6 @@ const SqlEditor = ({
     },
   );
 
-  console.log('tables', tables);
-
   const [height, setHeight] = useState(0);
   const [autorun, setAutorun] = useState(queryEditor.autorun);
   const [ctas, setCtas] = useState('');
@@ -651,44 +649,32 @@ const SqlEditor = ({
   const [NLPLoading, setNLPLoading] = useState(false);
   const handleNLPGeneration = async () => {
     setNLPLoading(true);
-    let tablesContext = "";
-    for(let t = 0; t < tables.length; t += 1) {
-      const table = tables[t];
-      if (table?.columns?.length) {
-        tablesContext += `# Table ${table.name}, columns = [`;
-        for(let c = 0; c < table.columns.length; c += 1) {
-          const col = table.columns[c];
-          tablesContext += col.name;
-          tablesContext += table.columns.at(-1)?.name === col.name ? "]" : ", "
-        }
-        tablesContext += `\n\n`;
-      }
-    }
-    tablesContext += `# Create a SQLite query to: ${NLPQuery}`;
     const postPayload = {
-      prompt: tablesContext,
+      to_sql: NLPQuery,
+      database_id: database.id,
+      database_backend: database.backend,
     }
     SupersetClient.post({
-      endpoint: "api/v1/sqllab/nlp",
+      endpoint: "api/v1/sqllab/nlp/tosql",
       headers: { 'Content-Type': 'application/json' },
-      body: JSON.stringify(postPayload),
       parseMethod: 'json-bigint',
+      body: JSON.stringify(postPayload),
     })
       .then(({ json }) => {
+        const sql = json.result.trim();
+        queryEditor.sql = sql;
         setEditorType('sql');
-        setNLPResult(json.result.trim());
+        setNLPResult(sql);
         setNLPLoading(false);
       })
       .catch(() => {
         setNLPLoading(false);
       });
-
-    console.log(tablesContext);
   }
   const renderNLPMenu = useMemo(() => (
     <Menu
       mode="horizontal"
-      defaultSelectedKeys={[editorType]}
+      selectedKeys={[editorType]}
       css={css`
         margin-bottom: 20px;
       `}
@@ -713,7 +699,7 @@ const SqlEditor = ({
   ), [NLPLoading, editorType]);
 
   const renderNLPBottomBar = (
-    tables.length > 0 ? <Button
+    <Button
       type="primary"
       size="large"
       onClick={handleNLPGeneration}
@@ -726,24 +712,17 @@ const SqlEditor = ({
     >
       {!NLPLoading ? <Icons.ExperimentOutlined /> : <Icons.LoadingOutlined />}{' '}
       Generate query
-    </Button> : null
+    </Button>
   );
 
-  const renderNLPForm =
-    tables.length > 0 ? (
-      <TextArea
+  const renderNLPForm = (
+    <TextArea
         disabled={NLPLoading}
         rows={6}
         onChange={e => setNLPQuery(e.target.value)}
         placeholder="Select all names from table"
-      />
-    ) : (
-      <Result
-        status="warning"
-        title={<small>Please select one or more tables</small>}
-        subTitle="Select one or more tables on the left pane for the AI algorithm to gain context about your query"
-      />
-    );
+     />
+  );
 
   const queryPane = () => {
     const hotkeys = getHotkeyConfig();
diff --git a/superset/config.py b/superset/config.py
index 5a64c99f77..4bfb47d835 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -1583,3 +1583,9 @@ elif importlib.util.find_spec("superset_config") and not is_test():
     except Exception:
         logger.exception("Found but failed to import local superset_config")
         raise
+
+# NLP
+
+OPENAI_API_KEY="INSERT_OPENAI_KEY_HERE"
+PINECONE_API_KEY="INSERT_PINECONE_KEY_HERE"
+PINECONE_INDEX_NAME="preset"
diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py
index 2179cec159..3dceaa74be 100644
--- a/superset/sqllab/api.py
+++ b/superset/sqllab/api.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
+import openai
+import pinecone
 from typing import Any, cast, Dict, Optional
 
 import simplejson as json
@@ -41,7 +43,7 @@ from superset.sqllab.execution_context_convertor import ExecutionContextConverto
 from superset.sqllab.query_render import SqlQueryRenderImpl
 from superset.sqllab.schemas import (
     ExecutePayloadSchema,
-    NLPPayloadSchema,
+    NLPtoSQLPayloadSchema,
     QueryExecutionResponseSchema,
     sql_lab_get_results_schema,
 )
@@ -56,7 +58,6 @@ from superset.superset_typing import FlaskResponse
 from superset.utils import core as utils
 from superset.views.base import json_success
 from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics
-import openai
 
 config = app.config
 logger = logging.getLogger(__name__)
@@ -71,47 +72,42 @@ class SqlLabRestApi(BaseSupersetApi):
     class_permission_name = "Query"
 
     execute_model_schema = ExecutePayloadSchema()
+    execute_nlp_to_sql_schema = NLPtoSQLPayloadSchema()
 
     apispec_parameter_schemas = {
         "sql_lab_get_results_schema": sql_lab_get_results_schema,
     }
     openapi_spec_tag = "SQL Lab"
     openapi_spec_component_schemas = (
-        NLPPayloadSchema,
+        NLPtoSQLPayloadSchema,
         ExecutePayloadSchema,
         QueryExecutionResponseSchema,
     )
 
-    @expose("/nlp/", methods=["POST"])
+    @expose("/nlp/tosql", methods=["POST"])
     # @protect()
     # @statsd_metrics
-    # @requires_json
+    @requires_json
     def execute_completion(self) -> FlaskResponse:
-        """Executes a SQL query
+        """Translates natural language to SQL
         ---
         post:
           description: >-
-            Starts the execution of a SQL query
+            Executes the translation to SQL
           requestBody:
             description: SQL query and params
             required: true
             content:
               application/json:
                 schema:
-                  $ref: '#/components/schemas/NLPPayloadSchema'
+                  $ref: '#/components/schemas/NLPtoSQLPayloadSchema'
           responses:
             200:
-              description: Query execution result
-              content:
-                application/json:
-                  schema:
-                    $ref: '#/components/schemas/NLPPayloadSchema'
-            202:
-              description: Query execution result, query still running
+              description: Natural language to SQL result
               content:
                 application/json:
                   schema:
-                    $ref: '#/components/schemas/NLPPayloadSchema'
+                    $ref: '#/components/schemas/NLPtoSQLPayloadSchema'
             400:
               $ref: '#/components/responses/400'
             401:
@@ -124,18 +120,64 @@ class SqlLabRestApi(BaseSupersetApi):
               $ref: '#/components/responses/500'
         """
         try:
-            openai.api_key = "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"
-            requestPrompt = request.json
-            completion = openai.Completion.create(
-                engine="text-davinci-003",
-                prompt=requestPrompt['prompt'],
-                max_tokens=100,
-                temperature=0.1,
-                stop="END"
+            self.execute_nlp_to_sql_schema.load(request.json)
+        except ValidationError as error:
+            return self.response_400(message=error.messages)
+        try:
+            openai.api_key = app.config["OPENAI_API_KEY"]
+            req_body = request.json
+            to_sql = req_body.get("to_sql")
+            database_id = req_body.get("database_id")
+            database_backend = req_body.get("database_backend")
+
+            pinecone.init(
+                api_key=app.config["PINECONE_API_KEY"],
+                environment="us-east1-gcp"
+            )
+
+            # attempts to get required datasources from vectors
+            pinecone_index_name = app.config["PINECONE_INDEX_NAME"]
+            pinecone_index = pinecone.Index(pinecone_index_name)
+            prompt_to_vectors_res = openai.Embedding.create(
+                input=[to_sql], engine="text-embedding-ada-002"
+            )
+            prompt_to_vectors = prompt_to_vectors_res['data'][0]['embedding']
+            pinecone_query = pinecone_index.query(
+                prompt_to_vectors,
+                top_k=2,
+                filter={
+                    "database_id": database_id,
+                },
+                namespace="datasource",
+                include_metadata=True
+            )
+            pinecone_matches = pinecone_query.get('matches', [])
+            all_sources = ""
+            for obj in pinecone_matches:
+                all_sources += f"{obj['metadata']['original']}\n"
+
+            prompt = ""
+            prompt += f"\n{all_sources}\n"
+            prompt += f"\nCOMMAND:\n{to_sql}\n"
+            prompt += "\nINSTRUCTIONS:\n"
+            prompt += "The 'COMMAND' given above is a natural language command that you must transform into one SQL query.\n"
+            prompt += "In order to generate the SQL query properly, follow the instructions below:\n"
+            prompt += "1. Only SELECT statements are allowed\n"
+            prompt += f"2. Use the SQL dialect {database_backend}\n"
+            prompt += "3. Use all table definitions above\n"
+            prompt += "4. Respond solely with the SQL query\n"
+            prompt += "\n{{ SQL_QUERY }}\n"
+
+            chat_completion = openai.ChatCompletion.create(
+                model="gpt-3.5-turbo",
+                temperature=0,
+                messages=[
+                    {"role": "user", "content": prompt},
+                ]
             )
-            choice = {**completion.choices[0]}
+            completion = {**chat_completion.choices[0]}
             payload = {
-                'result': choice['text'],
+                'result': completion['message']['content'],
             }
             return self.response(200, **payload)
         except Exception as e:
diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py
index 0ca760b757..d40bf50e38 100644
--- a/superset/sqllab/schemas.py
+++ b/superset/sqllab/schemas.py
@@ -23,8 +23,10 @@ sql_lab_get_results_schema = {
     },
     "required": ["key"],
 }
-class NLPPayloadSchema(Schema):
-    prompt = fields.String(required=True)
+class NLPtoSQLPayloadSchema(Schema):
+    to_sql = fields.String(required=True)
+    database_id = fields.Integer(required=True)
+    database_backend = fields.String(required=True)
 
 class ExecutePayloadSchema(Schema):
     database_id = fields.Integer(required=True)
diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py
index 4f158e8369..40b6f33e1e 100644
--- a/superset/views/datasource/views.py
+++ b/superset/views/datasource/views.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 import json
+import pinecone
+import openai
 from collections import Counter
 from typing import Any
 
@@ -27,7 +29,7 @@ from marshmallow import ValidationError
 from sqlalchemy.exc import NoSuchTableError
 from sqlalchemy.orm.exc import NoResultFound
 
-from superset import db, event_logger, security_manager
+from superset import app, db, event_logger, security_manager
 from superset.commands.utils import populate_owners
 from superset.connectors.sqla.models import SqlaTable
 from superset.connectors.sqla.utils import get_physical_table_metadata
@@ -125,6 +127,73 @@ class Datasource(BaseSupersetView):
             )
         orm_datasource.update_from_object(datasource_dict)
         data = orm_datasource.data
+
+        # transform the datasource info to vectors
+        datasource_table = datasource_dict.get('table_name')
+        datasource_schema = datasource_dict.get('schema')
+        datasource_columns = datasource_dict.get('columns')
+        datasource_desc = datasource_dict.get('description')
+        datasource_sel_star= datasource_dict.get('select_star')
+        database_backend = datasource_dict['database'].get('backend')
+        database_name = datasource_dict['database'].get('name')
+        stringified_columns = ""
+        for obj in datasource_columns:
+            stringified_columns += f"col_name: {obj['column_name']},"
+            stringified_columns += f"col_label: {obj['verbose_name']},"
+            stringified_columns += f"col_type: {obj['type']},"
+            stringified_columns += f"col_desc: {obj['description']}\n"
+        to_vectors = ""
+        to_vectors += f"# TABLE:\n"
+        to_vectors += f"table_name: {datasource_table}\n"
+        to_vectors += f"table_schema: {datasource_schema}\n"
+        to_vectors += f"table_desc: {datasource_desc}\n"
+
+        to_vectors += f"COLUMNS:\n"
+        to_vectors += stringified_columns
+
+        to_vectors += f"EXAMPLE:\n"
+        to_vectors += f"{datasource_sel_star}"
+
+        openai.api_key = app.config["OPENAI_API_KEY"]
+        pinecone.init(
+            api_key=app.config["PINECONE_API_KEY"],
+            environment="us-east1-gcp"
+        )
+        pinecone_index_name = app.config["PINECONE_INDEX_NAME"]
+        if pinecone_index_name not in pinecone.list_indexes():
+            # if does not exist, create index
+            pinecone.create_index(
+                name=pinecone_index_name,
+                # dimension of OpenAI embeddings
+                dimension=1536,
+                # use cosine similarity
+                metric='cosine'
+            )
+        # connect to Pinecone index
+        pinecone_index = pinecone.Index(pinecone_index_name)
+        # create embeddings with OpenAI
+        to_vectors_res = openai.Embedding.create(
+            input=[to_vectors], engine="text-embedding-ada-002"
+        )
+        embeddings = to_vectors_res['data'][0]['embedding']
+        vector_id = f"datasource-{datasource_id}"
+        # upsert vectors for this datasource
+        pinecone_index.upsert(
+            vectors=[(
+                vector_id,
+                embeddings,
+                {
+                    "original": to_vectors,
+                    "datasource_id": datasource_id,
+                    "datasource_schema": datasource_schema,
+                    "database_id": database_id,
+                    "database_name": database_name,
+                    "database_backend": database_backend
+                }
+            )],
+            namespace="datasource"
+        )
+
         db.session.commit()
 
         return self.json_response(sanitize_datasource_data(data))