You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by di...@apache.org on 2023/04/04 10:36:30 UTC
[superset] 03/03: Use cosine similarity
This is an automated email from the ASF dual-hosted git repository.
diegopucci pushed a commit to branch feat/sqllab-natural-language
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 184c6277f62f6c0b81c34c79615fa02ad9d4e3e2
Author: geido <di...@gmail.com>
AuthorDate: Tue Apr 4 12:35:59 2023 +0200
Use cosine similarity
---
.../src/SqlLab/components/SqlEditor/index.jsx | 51 ++++--------
superset/config.py | 6 ++
superset/sqllab/api.py | 94 ++++++++++++++++------
superset/sqllab/schemas.py | 6 +-
superset/views/datasource/views.py | 71 +++++++++++++++-
5 files changed, 163 insertions(+), 65 deletions(-)
diff --git a/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx b/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx
index b1f8c0532c..823d9c1184 100644
--- a/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx
+++ b/superset-frontend/src/SqlLab/components/SqlEditor/index.jsx
@@ -228,8 +228,6 @@ const SqlEditor = ({
},
);
- console.log('tables', tables);
-
const [height, setHeight] = useState(0);
const [autorun, setAutorun] = useState(queryEditor.autorun);
const [ctas, setCtas] = useState('');
@@ -651,44 +649,32 @@ const SqlEditor = ({
const [NLPLoading, setNLPLoading] = useState(false);
const handleNLPGeneration = async () => {
setNLPLoading(true);
- let tablesContext = "";
- for(let t = 0; t < tables.length; t += 1) {
- const table = tables[t];
- if (table?.columns?.length) {
- tablesContext += `# Table ${table.name}, columns = [`;
- for(let c = 0; c < table.columns.length; c += 1) {
- const col = table.columns[c];
- tablesContext += col.name;
- tablesContext += table.columns.at(-1)?.name === col.name ? "]" : ", "
- }
- tablesContext += `\n\n`;
- }
- }
- tablesContext += `# Create a SQLite query to: ${NLPQuery}`;
const postPayload = {
- prompt: tablesContext,
+ to_sql: NLPQuery,
+ database_id: database.id,
+ database_backend: database.backend,
}
SupersetClient.post({
- endpoint: "api/v1/sqllab/nlp",
+ endpoint: "api/v1/sqllab/nlp/tosql",
headers: { 'Content-Type': 'application/json' },
- body: JSON.stringify(postPayload),
parseMethod: 'json-bigint',
+ body: JSON.stringify(postPayload),
})
.then(({ json }) => {
+ const sql = json.result.trim();
+ queryEditor.sql = sql;
setEditorType('sql');
- setNLPResult(json.result.trim());
+ setNLPResult(sql);
setNLPLoading(false);
})
.catch(() => {
setNLPLoading(false);
});
-
- console.log(tablesContext);
}
const renderNLPMenu = useMemo(() => (
<Menu
mode="horizontal"
- defaultSelectedKeys={[editorType]}
+ selectedKeys={[editorType]}
css={css`
margin-bottom: 20px;
`}
@@ -713,7 +699,7 @@ const SqlEditor = ({
), [NLPLoading, editorType]);
const renderNLPBottomBar = (
- tables.length > 0 ? <Button
+ <Button
type="primary"
size="large"
onClick={handleNLPGeneration}
@@ -726,24 +712,17 @@ const SqlEditor = ({
>
{!NLPLoading ? <Icons.ExperimentOutlined /> : <Icons.LoadingOutlined />}{' '}
Generate query
- </Button> : null
+ </Button>
);
- const renderNLPForm =
- tables.length > 0 ? (
- <TextArea
+ const renderNLPForm = (
+ <TextArea
disabled={NLPLoading}
rows={6}
onChange={e => setNLPQuery(e.target.value)}
placeholder="Select all names from table"
- />
- ) : (
- <Result
- status="warning"
- title={<small>Please select one or more tables</small>}
- subTitle="Select one or more tables on the left pane for the AI algorithm to gain context about your query"
- />
- );
+ />
+ );
const queryPane = () => {
const hotkeys = getHotkeyConfig();
diff --git a/superset/config.py b/superset/config.py
index 5a64c99f77..4bfb47d835 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -1583,3 +1583,9 @@ elif importlib.util.find_spec("superset_config") and not is_test():
except Exception:
logger.exception("Found but failed to import local superset_config")
raise
+
+# NLP
+
+OPENAI_API_KEY="INSERT_OPENAI_KEY_HERE"
+PINECONE_API_KEY="INSERT_PINECONE_KEY_HERE"
+PINECONE_INDEX_NAME="preset"
diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py
index 2179cec159..3dceaa74be 100644
--- a/superset/sqllab/api.py
+++ b/superset/sqllab/api.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
import logging
+import openai
+import pinecone
from typing import Any, cast, Dict, Optional
import simplejson as json
@@ -41,7 +43,7 @@ from superset.sqllab.execution_context_convertor import ExecutionContextConverto
from superset.sqllab.query_render import SqlQueryRenderImpl
from superset.sqllab.schemas import (
ExecutePayloadSchema,
- NLPPayloadSchema,
+ NLPtoSQLPayloadSchema,
QueryExecutionResponseSchema,
sql_lab_get_results_schema,
)
@@ -56,7 +58,6 @@ from superset.superset_typing import FlaskResponse
from superset.utils import core as utils
from superset.views.base import json_success
from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics
-import openai
config = app.config
logger = logging.getLogger(__name__)
@@ -71,47 +72,42 @@ class SqlLabRestApi(BaseSupersetApi):
class_permission_name = "Query"
execute_model_schema = ExecutePayloadSchema()
+ execute_nlp_to_sql_schema = NLPtoSQLPayloadSchema()
apispec_parameter_schemas = {
"sql_lab_get_results_schema": sql_lab_get_results_schema,
}
openapi_spec_tag = "SQL Lab"
openapi_spec_component_schemas = (
- NLPPayloadSchema,
+ NLPtoSQLPayloadSchema,
ExecutePayloadSchema,
QueryExecutionResponseSchema,
)
- @expose("/nlp/", methods=["POST"])
+ @expose("/nlp/tosql", methods=["POST"])
# @protect()
# @statsd_metrics
- # @requires_json
+ @requires_json
def execute_completion(self) -> FlaskResponse:
- """Executes a SQL query
+ """Translates natural language to SQL
---
post:
description: >-
- Starts the execution of a SQL query
+ Executes the translation to SQL
requestBody:
description: SQL query and params
required: true
content:
application/json:
schema:
- $ref: '#/components/schemas/NLPPayloadSchema'
+ $ref: '#/components/schemas/NLPtoSQLPayloadSchema'
responses:
200:
- description: Query execution result
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/NLPPayloadSchema'
- 202:
- description: Query execution result, query still running
+ description: Natural language to SQL result
content:
application/json:
schema:
- $ref: '#/components/schemas/NLPPayloadSchema'
+ $ref: '#/components/schemas/NLPtoSQLPayloadSchema'
400:
$ref: '#/components/responses/400'
401:
@@ -124,18 +120,64 @@ class SqlLabRestApi(BaseSupersetApi):
$ref: '#/components/responses/500'
"""
try:
- openai.api_key = "sk-****REDACTED****"  # SECURITY: live OpenAI key was committed here and is exposed in this public archive — revoke/rotate it immediately
- requestPrompt = request.json
- completion = openai.Completion.create(
- engine="text-davinci-003",
- prompt=requestPrompt['prompt'],
- max_tokens=100,
- temperature=0.1,
- stop="END"
+ self.execute_nlp_to_sql_schema.load(request.json)
+ except ValidationError as error:
+ return self.response_400(message=error.messages)
+ try:
+ openai.api_key = app.config["OPENAI_API_KEY"]
+ req_body = request.json
+ to_sql = req_body.get("to_sql")
+ database_id = req_body.get("database_id")
+ database_backend = req_body.get("database_backend")
+
+ pinecone.init(
+ api_key=app.config["PINECONE_API_KEY"],
+ environment="us-east1-gcp"
+ )
+
+ # attempts to get required datasources from vectors
+ pinecone_index_name = app.config["PINECONE_INDEX_NAME"]
+ pinecone_index = pinecone.Index(pinecone_index_name)
+ prompt_to_vectors_res = openai.Embedding.create(
+ input=[to_sql], engine="text-embedding-ada-002"
+ )
+ prompt_to_vectors = prompt_to_vectors_res['data'][0]['embedding']
+ pinecone_query = pinecone_index.query(
+ prompt_to_vectors,
+ top_k=2,
+ filter={
+ "database_id": database_id,
+ },
+ namespace="datasource",
+ include_metadata=True
+ )
+ pinecone_matches = pinecone_query.get('matches', [])
+ all_sources = ""
+ for obj in pinecone_matches:
+ all_sources += f"{obj['metadata']['original']}\n"
+
+ prompt = ""
+ prompt += f"\n{all_sources}\n"
+ prompt += f"\nCOMMAND:\n{to_sql}\n"
+ prompt += "\nINSTRUCTIONS:\n"
+ prompt += "The 'COMMAND' given above is a natural language command that you must transform into one SQL query.\n"
+ prompt += "In order to generate the SQL query properly, follow the instructions below:\n"
+ prompt += "1. Only SELECT statements are allowed\n"
+ prompt += f"2. Use the SQL dialect {database_backend}\n"
+ prompt += "3. Use all table definitions above\n"
+ prompt += "4. Respond solely with the SQL query\n"
+ prompt += "\n{{ SQL_QUERY }}\n"
+
+ chat_completion = openai.ChatCompletion.create(
+ model="gpt-3.5-turbo",
+ temperature=0,
+ messages=[
+ {"role": "user", "content": prompt},
+ ]
)
- choice = {**completion.choices[0]}
+ completion = {**chat_completion.choices[0]}
payload = {
- 'result': choice['text'],
+ 'result': completion['message']['content'],
}
return self.response(200, **payload)
except Exception as e:
diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py
index 0ca760b757..d40bf50e38 100644
--- a/superset/sqllab/schemas.py
+++ b/superset/sqllab/schemas.py
@@ -23,8 +23,10 @@ sql_lab_get_results_schema = {
},
"required": ["key"],
}
-class NLPPayloadSchema(Schema):
- prompt = fields.String(required=True)
+class NLPtoSQLPayloadSchema(Schema):
+ to_sql = fields.String(required=True)
+ database_id = fields.Integer(required=True)
+ database_backend = fields.String(required=True)
class ExecutePayloadSchema(Schema):
database_id = fields.Integer(required=True)
diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py
index 4f158e8369..40b6f33e1e 100644
--- a/superset/views/datasource/views.py
+++ b/superset/views/datasource/views.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
import json
+import pinecone
+import openai
from collections import Counter
from typing import Any
@@ -27,7 +29,7 @@ from marshmallow import ValidationError
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm.exc import NoResultFound
-from superset import db, event_logger, security_manager
+from superset import app, db, event_logger, security_manager
from superset.commands.utils import populate_owners
from superset.connectors.sqla.models import SqlaTable
from superset.connectors.sqla.utils import get_physical_table_metadata
@@ -125,6 +127,73 @@ class Datasource(BaseSupersetView):
)
orm_datasource.update_from_object(datasource_dict)
data = orm_datasource.data
+
+ # transform the datasource info to vectors
+ datasource_table = datasource_dict.get('table_name')
+ datasource_schema = datasource_dict.get('schema')
+ datasource_columns = datasource_dict.get('columns')
+ datasource_desc = datasource_dict.get('description')
+ datasource_sel_star= datasource_dict.get('select_star')
+ database_backend = datasource_dict['database'].get('backend')
+ database_name = datasource_dict['database'].get('name')
+ stringified_columns = ""
+ for obj in datasource_columns:
+ stringified_columns += f"col_name: {obj['column_name']},"
+ stringified_columns += f"col_label: {obj['verbose_name']},"
+ stringified_columns += f"col_type: {obj['type']},"
+ stringified_columns += f"col_desc: {obj['description']}\n"
+ to_vectors = ""
+ to_vectors += f"# TABLE:\n"
+ to_vectors += f"table_name: {datasource_table}\n"
+ to_vectors += f"table_schema: {datasource_schema}\n"
+ to_vectors += f"table_desc: {datasource_desc}\n"
+
+ to_vectors += f"COLUMNS:\n"
+ to_vectors += stringified_columns
+
+ to_vectors += f"EXAMPLE:\n"
+ to_vectors += f"{datasource_sel_star}"
+
+ openai.api_key = app.config["OPENAI_API_KEY"]
+ pinecone.init(
+ api_key=app.config["PINECONE_API_KEY"],
+ environment="us-east1-gcp"
+ )
+ pinecone_index_name = app.config["PINECONE_INDEX_NAME"]
+ if pinecone_index_name not in pinecone.list_indexes():
+ # if does not exist, create index
+ pinecone.create_index(
+ name=pinecone_index_name,
+ # dimension of OpenAI embeddings
+ dimension=1536,
+ # use cosine similarity
+ metric='cosine'
+ )
+ # connect to Pinecone index
+ pinecone_index = pinecone.Index(pinecone_index_name)
+ # create embeddings with OpenAI
+ to_vectors_res = openai.Embedding.create(
+ input=[to_vectors], engine="text-embedding-ada-002"
+ )
+ embeddings = to_vectors_res['data'][0]['embedding']
+ vector_id = f"datasource-{datasource_id}"
+ # upsert vectors for this datasource
+ pinecone_index.upsert(
+ vectors=[(
+ vector_id,
+ embeddings,
+ {
+ "original": to_vectors,
+ "datasource_id": datasource_id,
+ "datasource_schema": datasource_schema,
+ "database_id": database_id,
+ "database_name": database_name,
+ "database_backend": database_backend
+ }
+ )],
+ namespace="datasource"
+ )
+
db.session.commit()
return self.json_response(sanitize_datasource_data(data))