You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@superset.apache.org by dp...@apache.org on 2020/02/20 10:15:44 UTC

[incubator-superset] branch master updated: [database] Fix, tables API endpoint (#9144)

This is an automated email from the ASF dual-hosted git repository.

dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new e55fe43  [database] Fix, tables API endpoint (#9144)
e55fe43 is described below

commit e55fe43ca67a29518674a1a2137a3dbd4f166864
Author: Daniel Vaz Gaspar <da...@gmail.com>
AuthorDate: Thu Feb 20 10:15:22 2020 +0000

    [database] Fix, tables API endpoint (#9144)
---
 superset/views/core.py         | 71 ++++++++++++++++++++----------------------
 superset/views/database/api.py |  2 ++
 tests/base_tests.py            |  7 +++++
 tests/core_tests.py            | 37 ++++++++++++++++++++++
 4 files changed, 80 insertions(+), 37 deletions(-)

diff --git a/superset/views/core.py b/superset/views/core.py
index f94a906..1c4640c 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -27,17 +27,7 @@ import msgpack
 import pandas as pd
 import pyarrow as pa
 import simplejson as json
-from flask import (
-    abort,
-    flash,
-    g,
-    Markup,
-    redirect,
-    render_template,
-    request,
-    Response,
-    url_for,
-)
+from flask import abort, flash, g, Markup, redirect, render_template, request, Response
 from flask_appbuilder import expose
 from flask_appbuilder.models.sqla.interface import SQLAInterface
 from flask_appbuilder.security.decorators import has_access, has_access_api
@@ -46,7 +36,6 @@ from flask_babel import gettext as __, lazy_gettext as _
 from sqlalchemy import and_, Integer, or_, select
 from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.orm.session import Session
-from werkzeug.routing import BaseConverter
 from werkzeug.urls import Href
 
 import superset.models.core as models
@@ -88,11 +77,10 @@ from superset.sql_validators import get_validator_by_name
 from superset.utils import core as utils, dashboard_import_export
 from superset.utils.dates import now_as_float
 from superset.utils.decorators import etag_cache, stats_timing
-from superset.views.chart import views as chart_views
+from superset.views.database.filters import DatabaseFilter
 
 from .base import (
     api,
-    BaseFilter,
     BaseSupersetView,
     check_ownership,
     common_bootstrap_payload,
@@ -107,8 +95,6 @@ from .base import (
     json_success,
     SupersetModelView,
 )
-from .dashboard import views as dash_views
-from .database import views as in_views
 from .utils import (
     apply_display_max_row_limit,
     bootstrap_user_data,
@@ -1068,21 +1054,30 @@ class Superset(BaseSupersetView):
 
     @api
     @has_access_api
-    @expose("/tables/<db_id>/<schema>/<substr>/")
-    @expose("/tables/<db_id>/<schema>/<substr>/<force_refresh>/")
-    def tables(self, db_id, schema, substr, force_refresh="false"):
+    @expose("/tables/<int:db_id>/<schema>/<substr>/")
+    @expose("/tables/<int:db_id>/<schema>/<substr>/<force_refresh>/")
+    def tables(
+        self, db_id: int, schema: str, substr: str, force_refresh: str = "false"
+    ):
         """Endpoint to fetch the list of tables for given database"""
-        db_id = int(db_id)
-        force_refresh = force_refresh.lower() == "true"
-        schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
-        substr = utils.parse_js_uri_path_item(substr, eval_undefined=True)
-        database = db.session.query(models.Database).filter_by(id=db_id).one()
+        # Guarantees database filtering by security access
+        query = db.session.query(models.Database)
+        query = DatabaseFilter("id", SQLAInterface(models.Database, db.session)).apply(
+            query, None
+        )
+        database = query.filter_by(id=db_id).one_or_none()
+        if not database:
+            return json_error_response("Not found", 404)
 
-        if schema:
+        force_refresh_parsed = force_refresh.lower() == "true"
+        schema_parsed = utils.parse_js_uri_path_item(schema, eval_undefined=True)
+        substr_parsed = utils.parse_js_uri_path_item(substr, eval_undefined=True)
+
+        if schema_parsed:
             tables = (
                 database.get_all_table_names_in_schema(
-                    schema=schema,
-                    force=force_refresh,
+                    schema=schema_parsed,
+                    force=force_refresh_parsed,
                     cache=database.table_cache_enabled,
                     cache_timeout=database.table_cache_timeout,
                 )
@@ -1090,8 +1085,8 @@ class Superset(BaseSupersetView):
             )
             views = (
                 database.get_all_view_names_in_schema(
-                    schema=schema,
-                    force=force_refresh,
+                    schema=schema_parsed,
+                    force=force_refresh_parsed,
                     cache=database.table_cache_enabled,
                     cache_timeout=database.table_cache_timeout,
                 )
@@ -1105,20 +1100,22 @@ class Superset(BaseSupersetView):
                 cache=True, force=False, cache_timeout=24 * 60 * 60
             )
         tables = security_manager.get_datasources_accessible_by_user(
-            database, tables, schema
+            database, tables, schema_parsed
         )
         views = security_manager.get_datasources_accessible_by_user(
-            database, views, schema
+            database, views, schema_parsed
         )
 
         def get_datasource_label(ds_name: utils.DatasourceName) -> str:
-            return ds_name.table if schema else f"{ds_name.schema}.{ds_name.table}"
+            return (
+                ds_name.table if schema_parsed else f"{ds_name.schema}.{ds_name.table}"
+            )
 
-        if substr:
-            tables = [tn for tn in tables if substr in get_datasource_label(tn)]
-            views = [vn for vn in views if substr in get_datasource_label(vn)]
+        if substr_parsed:
+            tables = [tn for tn in tables if substr_parsed in get_datasource_label(tn)]
+            views = [vn for vn in views if substr_parsed in get_datasource_label(vn)]
 
-        if not schema and database.default_schemas:
+        if not schema_parsed and database.default_schemas:
             user_schema = g.user.email.split("@")[0]
             valid_schemas = set(database.default_schemas + [user_schema])
 
@@ -1129,7 +1126,7 @@ class Superset(BaseSupersetView):
         total_items = len(tables) + len(views)
         max_tables = len(tables)
         max_views = len(views)
-        if total_items and substr:
+        if total_items and substr_parsed:
             max_tables = max_items * len(tables) // total_items
             max_views = max_items * len(views) // total_items
 
diff --git a/superset/views/database/api.py b/superset/views/database/api.py
index 2eb5ad8..d4ee3c9 100644
--- a/superset/views/database/api.py
+++ b/superset/views/database/api.py
@@ -143,6 +143,8 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi):
     max_page_size = -1
     validators_columns = {"sqlalchemy_uri": sqlalchemy_uri_validator}
 
+    openapi_spec_tag = "Database"
+
     @expose(
         "/<int:pk>/table/<string:table_name>/<string:schema_name>/", methods=["GET"]
     )
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 94f0ceb..5728dd8 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -40,6 +40,13 @@ FAKE_DB_NAME = "fake_db_100"
 
 
 class SupersetTestCase(TestCase):
+
+    default_schema_backend_map = {
+        "sqlite": "main",
+        "mysql": "superset",
+        "postgresql": "public",
+    }
+
     def __init__(self, *args, **kwargs):
         super(SupersetTestCase, self).__init__(*args, **kwargs)
         self.maxDiff = None
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 19a5ccd..cf93d7a 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -165,6 +165,43 @@ class CoreTests(SupersetTestCase):
         # the new cache_key should be different due to updated datasource
         self.assertNotEqual(cache_key_original, cache_key_new)
 
+    def test_get_superset_tables_not_allowed(self):
+        example_db = utils.get_example_database()
+        schema_name = self.default_schema_backend_map[example_db.backend]
+        self.login(username="gamma")
+        uri = f"superset/tables/{example_db.id}/{schema_name}/undefined/"
+        rv = self.client.get(uri)
+        self.assertEqual(rv.status_code, 404)
+
+    def test_get_superset_tables_substr(self):
+        example_db = utils.get_example_database()
+        self.login(username="admin")
+        schema_name = self.default_schema_backend_map[example_db.backend]
+        uri = f"superset/tables/{example_db.id}/{schema_name}/ab_role/"
+        rv = self.client.get(uri)
+        response = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(rv.status_code, 200)
+
+        expeted_response = {
+            "options": [
+                {
+                    "label": "ab_role",
+                    "schema": schema_name,
+                    "title": "ab_role",
+                    "type": "table",
+                    "value": "ab_role",
+                }
+            ],
+            "tableLength": 1,
+        }
+        self.assertEqual(response, expeted_response)
+
+    def test_get_superset_tables_not_found(self):
+        self.login(username="admin")
+        uri = f"superset/tables/invalid/public/undefined/"
+        rv = self.client.get(uri)
+        self.assertEqual(rv.status_code, 404)
+
     def test_api_v1_query_endpoint(self):
         self.login(username="admin")
         qc_dict = self._get_query_context_dict()