Posted to commits@superset.apache.org by ar...@apache.org on 2023/07/06 20:48:54 UTC

[superset] branch master updated: feat(database): Database Filtering via custom configuration (#24580)

This is an automated email from the ASF dual-hosted git repository.

arivero pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 6657353bca feat(database): Database Filtering via custom configuration (#24580)
6657353bca is described below

commit 6657353bcafbfd4dcbd6596bfb97f5ace179d7e4
Author: Antonio Rivero <38...@users.noreply.github.com>
AuthorDate: Thu Jul 6 16:48:46 2023 -0400

    feat(database): Database Filtering via custom configuration (#24580)
---
 superset/config.py                             | 20 ++++++
 superset/databases/filters.py                  | 15 +++-
 tests/integration_tests/databases/api_tests.py | 93 ++++++++++++++++++++++++
 tests/unit_tests/databases/api_test.py         | 99 ++++++++++++++++++++++++++
 4 files changed, 226 insertions(+), 1 deletion(-)

diff --git a/superset/config.py b/superset/config.py
index 1dcc9220dd..9c6adf599b 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -1594,6 +1594,26 @@ class ExtraRelatedQueryFilters(TypedDict, total=False):
 EXTRA_RELATED_QUERY_FILTERS: ExtraRelatedQueryFilters = {}
 
 
+# Extra dynamic query filters make it possible to limit which objects are shown
+# in the UI before any other filtering is applied. Useful, for example, when
+# filtering with Feature Flags in addition to the regular role filters
+# that are applied by default in our base_filters.
+# For example, to only show databases whose names start with the letter "b"
+# in the "Database Connections" list, you could add the following to your config:
+# def initial_database_filter(query: Query, *args, **kwargs):
+#     from superset.models.core import Database
+#
+#     name_filter = Database.database_name.startswith('b')
+#     return query.filter(name_filter)
+#
+# EXTRA_DYNAMIC_QUERY_FILTERS = {"databases": initial_database_filter}
+class ExtraDynamicQueryFilters(TypedDict, total=False):
+    databases: Callable[[Query], Query]
+
+
+EXTRA_DYNAMIC_QUERY_FILTERS: ExtraDynamicQueryFilters = {}
+
+
 # -------------------------------------------------------------------
 # *                WARNING:  STOP EDITING  HERE                    *
 # -------------------------------------------------------------------
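For operators, the new hook only needs a callable in the Superset config that takes and returns a SQLAlchemy Query. A minimal sketch of such a config, mirroring the comment above (the superset_config.py module name and the "dyntest" name prefix are illustrative, not part of this commit):

    # superset_config.py -- illustrative sketch only
    from sqlalchemy.orm import Query


    def only_dyntest_databases(query: Query) -> Query:
        # Import lazily so the config module does not pull in Superset models at import time.
        from superset.models.core import Database

        return query.filter(Database.database_name.startswith("dyntest"))


    EXTRA_DYNAMIC_QUERY_FILTERS = {"databases": only_dyntest_databases}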
diff --git a/superset/databases/filters.py b/superset/databases/filters.py
index 2ca77b77d1..384a62c9d3 100644
--- a/superset/databases/filters.py
+++ b/superset/databases/filters.py
@@ -16,7 +16,7 @@
 # under the License.
 from typing import Any
 
-from flask import g
+from flask import current_app, g
 from flask_babel import lazy_gettext as _
 from sqlalchemy import or_
 from sqlalchemy.orm import Query
@@ -41,6 +41,19 @@ class DatabaseFilter(BaseFilter):  # pylint: disable=too-few-public-methods
     # TODO(bogdan): consider caching.
 
     def apply(self, query: Query, value: Any) -> Query:
+        """
+        Dynamic filters need to be applied to the query before any other
+        database filtering. This way you can show/hide databases using,
+        for example, Feature Flags in conjunction with the regular role
+        filtering. Otherwise, a user with access to all databases would
+        skip this dynamic filtering entirely.
+        """
+
+        if dynamic_filters := current_app.config["EXTRA_DYNAMIC_QUERY_FILTERS"]:
+            if dynamic_databases_filter := dynamic_filters.get("databases"):
+                query = dynamic_databases_filter(query)
+
+        # We can proceed with default filtering now
         if security_manager.can_access_all_databases():
             return query
         database_perms = security_manager.user_view_menu_names("database_access")
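The Feature Flag use case mentioned in the docstring could be wired up roughly as below. HIDE_STAGING_DATABASES is a hypothetical custom flag an operator adds to FEATURE_FLAGS, and reading it through feature_flag_manager is just one way to do the lookup; none of this is added by the commit itself:

    # superset_config.py -- illustrative sketch only
    from sqlalchemy.orm import Query

    FEATURE_FLAGS = {"HIDE_STAGING_DATABASES": True}  # hypothetical custom flag


    def hide_staging_databases(query: Query) -> Query:
        from superset.extensions import feature_flag_manager
        from superset.models.core import Database

        if feature_flag_manager.is_feature_enabled("HIDE_STAGING_DATABASES"):
            # Hide any connection whose name marks it as a staging database.
            return query.filter(~Database.database_name.ilike("%staging%"))
        return query


    EXTRA_DYNAMIC_QUERY_FILTERS = {"databases": hide_staging_databases}

Because DatabaseFilter.apply runs this callable before the role-based checks, such connections stay hidden even for users who can access all databases.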
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index 568ba05934..ebf94219c3 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -28,6 +28,8 @@ import prison
 import pytest
 import yaml
 
+from unittest.mock import Mock
+
 from sqlalchemy.engine.url import make_url
 from sqlalchemy.exc import DBAPIError
 from sqlalchemy.sql import func
@@ -3632,3 +3634,94 @@ class TestDatabaseApi(SupersetTestCase):
             return
         self.assertEqual(rv.status_code, 422)
         self.assertIn("Kaboom!", response["errors"][0]["message"])
+
+    def test_get_databases_with_extra_filters(self):
+        """
+        API: Test get database with an extra query filter.
+        First we test the default behavior: all databases
+        must be returned if nothing is set in the config.
+        Then we patch the config to add the filter function
+        and verify that it is applied.
+        """
+        self.login(username="admin")
+        extra = {
+            "metadata_params": {},
+            "engine_params": {},
+            "metadata_cache_timeout": {},
+            "schemas_allowed_for_file_upload": [],
+        }
+        example_db = get_example_database()
+
+        if example_db.backend == "sqlite":
+            return
+        # Create our two databases
+        database_data = {
+            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
+            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
+            "server_cert": None,
+            "extra": json.dumps(extra),
+        }
+
+        uri = "api/v1/database/"
+        rv = self.client.post(
+            uri, json={**database_data, "database_name": "dyntest-create-database-1"}
+        )
+        first_response = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(rv.status_code, 201)
+
+        uri = "api/v1/database/"
+        rv = self.client.post(
+            uri, json={**database_data, "database_name": "create-database-2"}
+        )
+        second_response = json.loads(rv.data.decode("utf-8"))
+        self.assertEqual(rv.status_code, 201)
+
+        # The filter function
+        def _base_filter(query):
+            from superset.models.core import Database
+
+            return query.filter(Database.database_name.startswith("dyntest"))
+
+        # Create the Mock
+        base_filter_mock = Mock(side_effect=_base_filter)
+        dbs = db.session.query(Database).all()
+        expected_names = [db.database_name for db in dbs]
+        expected_names.sort()
+
+        uri = f"api/v1/database/"
+        # Get the list of databases without filter in the config
+        rv = self.client.get(uri)
+        data = json.loads(rv.data.decode("utf-8"))
+        # All databases must be returned if no filter is present
+        self.assertEqual(data["count"], len(dbs))
+        database_names = [item["database_name"] for item in data["result"]]
+        database_names.sort()
+        # All databases are returned because we are an admin
+        self.assertEqual(database_names, expected_names)
+        assert rv.status_code == 200
+        # Our filter function was not called
+        base_filter_mock.assert_not_called()
+
+        # Now we patch the config to include our filter function
+        with patch.dict(
+            "superset.views.filters.current_app.config",
+            {"EXTRA_DYNAMIC_QUERY_FILTERS": {"databases": base_filter_mock}},
+        ):
+            uri = f"api/v1/database/"
+            rv = self.client.get(uri)
+            data = json.loads(rv.data.decode("utf-8"))
+            # Only one database name starts with "dyntest"
+            self.assertEqual(data["count"], 1)
+            database_names = [item["database_name"] for item in data["result"]]
+            # Only the database whose name starts with "dyntest", even though we are an admin
+            self.assertEqual(database_names, ["dyntest-create-database-1"])
+            assert rv.status_code == 200
+            # The filter function is called now that it's defined in our config
+            base_filter_mock.assert_called()
+
+        # Cleanup
+        first_model = db.session.query(Database).get(first_response.get("id"))
+        second_model = db.session.query(Database).get(second_response.get("id"))
+        db.session.delete(first_model)
+        db.session.delete(second_model)
+        db.session.commit()
diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py
index 24fde88369..899e2b0234 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -20,9 +20,11 @@
 import json
 from io import BytesIO
 from typing import Any
+from unittest.mock import Mock
 from uuid import UUID
 
 import pytest
+from flask import current_app
 from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session
 
@@ -495,3 +497,100 @@ def test_delete_ssh_tunnel_not_found(
 
         response_tunnel = DatabaseDAO.get_ssh_tunnel(2)
         assert response_tunnel is None
+
+
+def test_apply_dynamic_database_filter(
+    mocker: MockFixture,
+    app: Any,
+    session: Session,
+    client: Any,
+    full_api_access: None,
+) -> None:
+    """
+    Test that we can filter the list of databases.
+    First test the default behavior without a filter, then
+    define a filter function and patch the config to get
+    the filtered results.
+    """
+    with app.app_context():
+        from superset.daos.database import DatabaseDAO
+        from superset.databases.api import DatabaseRestApi
+        from superset.databases.ssh_tunnel.models import SSHTunnel
+        from superset.models.core import Database
+
+        DatabaseRestApi.datamodel.session = session
+
+        # create table for databases
+        Database.metadata.create_all(session.get_bind())  # pylint: disable=no-member
+
+        # Create our First Database
+        database = Database(
+            database_name="first-database",
+            sqlalchemy_uri="gsheets://",
+            encrypted_extra=json.dumps(
+                {
+                    "metadata_params": {},
+                    "engine_params": {},
+                    "metadata_cache_timeout": {},
+                    "schemas_allowed_for_file_upload": [],
+                }
+            ),
+        )
+        session.add(database)
+        session.commit()
+
+        # Create our Second Database
+        database = Database(
+            database_name="second-database",
+            sqlalchemy_uri="gsheets://",
+            encrypted_extra=json.dumps(
+                {
+                    "metadata_params": {},
+                    "engine_params": {},
+                    "metadata_cache_timeout": {},
+                    "schemas_allowed_for_file_upload": [],
+                }
+            ),
+        )
+        session.add(database)
+        session.commit()
+
+        # mock the lookup so that we don't need to include the driver
+        mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
+        mocker.patch("superset.utils.log.DBEventLogger.log")
+        mocker.patch(
+            "superset.databases.ssh_tunnel.commands.delete.is_feature_enabled",
+            return_value=False,
+        )
+
+        def _base_filter(query):
+            from superset.models.core import Database
+
+            return query.filter(Database.database_name.startswith("second"))
+
+        # Create a mock object
+        base_filter_mock = Mock(side_effect=_base_filter)
+
+        # Get our recently created Databases
+        response_databases = DatabaseDAO.find_all()
+        assert response_databases
+        expected_db_names = ["first-database", "second-database"]
+        actual_db_names = [db.database_name for db in response_databases]
+        assert actual_db_names == expected_db_names
+
+        # Ensure that the filter has not been called because it's not in our config
+        assert base_filter_mock.call_count == 0
+
+        original_config = current_app.config.copy()
+        original_config["EXTRA_DYNAMIC_QUERY_FILTERS"] = {"databases": base_filter_mock}
+
+        mocker.patch("superset.views.filters.current_app.config", new=original_config)
+        # Get filtered list
+        response_databases = DatabaseDAO.find_all()
+        assert response_databases
+        expected_db_names = ["second-database"]
+        actual_db_names = [db.database_name for db in response_databases]
+        assert actual_db_names == expected_db_names
+
+        # Ensure that the filter has been called once
+        assert base_filter_mock.call_count == 1