You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2023/03/15 18:34:39 UTC

[superset] 01/01: fix: improve schema security

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch default_schema
in repository https://gitbox.apache.org/repos/asf/superset.git

commit dfb3904be8654b7dbb4a77caafbfb2202b2c58bd
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Wed Mar 15 11:34:11 2023 -0700

    fix: improve schema security
---
 superset/db_engine_specs/base.py          | 73 +++++++++++++++++++++++++++++--
 superset/db_engine_specs/drill.py         | 13 +++++-
 superset/db_engine_specs/hive.py          | 13 +++++-
 superset/db_engine_specs/mysql.py         | 19 +++++++-
 superset/db_engine_specs/postgres.py      | 29 ++++++++++++
 superset/db_engine_specs/presto.py        | 23 +++++++++-
 superset/db_engine_specs/snowflake.py     | 18 +++++++-
 superset/models/core.py                   | 17 +++++++
 superset/security/manager.py              | 13 +-----
 tests/unit_tests/security/manager_test.py |  3 +-
 10 files changed, 198 insertions(+), 23 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 9d64ad8fb3..36eb06a8fa 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -372,7 +372,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
     # Is the DB engine spec able to change the default schema? This requires implementing
     # a custom `adjust_database_uri` method.
-    dynamic_schema = False
+    supports_dynamic_schema = False
 
     @classmethod
     def supports_url(cls, url: URL) -> bool:
@@ -426,6 +426,73 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
         return driver in cls.drivers
 
+    @classmethod
+    def get_default_schema(cls, database: Database) -> Optional[str]:
+        """
+        Return the default schema in a given database.
+
+        Some SQLAlchemy dialects can connect directly to a schema (eg, MySQL, though it
+        calls it a "database"), so they should return that as the default schema. For
+        other dialects like Postgres, where the default schema it not part of the URI,
+        we need to fetch the information from the DB.
+        """
+        with database.get_inspector_with_context() as inspector:
+            return inspector.default_schema_name
+
+    @classmethod
+    def get_schema_from_engine_params(
+        cls,
+        sqlalchemy_uri: URL,
+        connect_args: Dict[str, Any],
+    ) -> Optional[str]:
+        """
+        Return the schema configured in a SQLAlchemy URI and connection arguments,
+        if any.
+
+        This implementation ignores both arguments and always returns ``None``,
+        meaning "no schema is configured in the engine parameters".
+        """
+        return None
+
+    @classmethod
+    def get_default_schema_for_query(
+        cls,
+        database: Database,
+        query: Query,
+    ) -> Optional[str]:
+        """
+        Return the default schema for a given query.
+
+        This is used to determine the schema of tables that aren't fully qualified, eg:
+
+            SELECT * FROM foo;
+
+        In the example above, the schema where the `foo` table lives depends on a few
+        factors:
+
+            1. For DB engine specs that allow dynamically changing the schema based on the
+               query we should use the query schema.
+            2. For DB engine specs that don't support dynamically changing the schema and
+               have the schema hardcoded in the SQLAlchemy URI we should use the schema
+               from the URI.
+            3. For DB engine specs that don't connect to a specific schema and can't
+               change it dynamically we need to probe the database for the default schema.
+
+        Determining the correct schema is crucial for managing access to data, so please
+        make sure you understand this logic when working on a new DB engine spec.
+        """
+        # default schema varies on a per-query basis
+        if self.supports_dynamic_schema:
+            return query.schema
+
+        # check if the schema is stored in the SQLAlchemy URI or connection arguments
+        try:
+            connect_args = database.get_extra()["engine_params"]["connect_args"]
+        except KeyError:
+            connect_args = {}
+        sqlalchemy_uri = make_url_safe(database.sqlalchemy_uri)
+        if schema := cls.get_schema_from_engine_params(sqlalchemy_uri, connect_args):
+            return schema
+
+        # return the default schema of the database
+        return self.get_default_schema(database)
+
     @classmethod
     def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
         """
@@ -1061,7 +1128,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             raise cls.get_dbapi_mapped_exception(ex) from ex
 
         if schema and cls.try_remove_schema_from_table_name:
-            tables = {re.sub(f"^{schema}\\.", "", table) for table in tables}
+            tables = {table[len(schema) + 1 :] for table in tables}
         return tables
 
     @classmethod
@@ -1089,7 +1156,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             raise cls.get_dbapi_mapped_exception(ex) from ex
 
         if schema and cls.try_remove_schema_from_table_name:
-            views = {re.sub(f"^{schema}\\.", "", view) for view in views}
+            view = {view[len(schema) + 1 :] for view in views}
         return views
 
     @classmethod
diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py
index f14bdd79f0..a1ec2d0312 100644
--- a/superset/db_engine_specs/drill.py
+++ b/superset/db_engine_specs/drill.py
@@ -32,7 +32,7 @@ class DrillEngineSpec(BaseEngineSpec):
     engine_name = "Apache Drill"
     default_driver = "sadrill"
 
-    dynamic_schema = True
+    supports_dynamic_schema = True
 
     _time_grain_expressions = {
         None: "{col}",
@@ -77,6 +77,17 @@ class DrillEngineSpec(BaseEngineSpec):
 
         return uri
 
+    @classmethod
+    def get_schema_from_engine_params(
+        cls,
+        sqlalchemy_uri: URL,
+        connect_args: Dict[str, Any],
+    ) -> Optional[str]:
+        """
+        Return the configured schema.
+        """
+        return parse.unquote(sqlalchemy_uri.database)
+
     @classmethod
     def get_url_for_impersonation(
         cls, url: URL, impersonate_user: bool, username: Optional[str]
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 8c69ab3fc7..f90d889f8c 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -98,7 +98,7 @@ class HiveEngineSpec(PrestoEngineSpec):
     allows_alias_to_source_column = True
     allows_hidden_orderby_agg = False
 
-    dynamic_schema = True
+    supports_dynamic_schema = True
 
     # When running `SHOW FUNCTIONS`, what is the name of the column with the
     # function names?
@@ -268,6 +268,17 @@ class HiveEngineSpec(PrestoEngineSpec):
 
         return uri
 
+    @classmethod
+    def get_schema_from_engine_params(
+        cls,
+        sqlalchemy_uri: URL,
+        connect_args: Dict[str, Any],
+    ) -> Optional[str]:
+        """
+        Return the configured schema.
+        """
+        return parse.unquote(sqlalchemy_uri.database)
+
     @classmethod
     def _extract_error_message(cls, ex: Exception) -> str:
         msg = str(ex)
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 75c1c69789..04b8c68dd7 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -69,7 +69,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
     )
     encryption_parameters = {"ssl": "1"}
 
-    dynamic_schema = True
+    supports_dynamic_schema = True
 
     column_type_mappings = (
         (
@@ -192,13 +192,28 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
 
     @classmethod
     def adjust_database_uri(
-        cls, uri: URL, selected_schema: Optional[str] = None
+        cls,
+        uri: URL,
+        selected_schema: Optional[str] = None,
     ) -> URL:
+        """
+        Return ``uri`` pointing at ``selected_schema``, when one is given.
+
+        The schema is percent-encoded (with no characters treated as safe, so "/" is
+        encoded too) and set as the database component of the URI. Without a
+        selected schema ``uri`` is returned as is.
+        """
         if selected_schema:
             uri = uri.set(database=parse.quote(selected_schema, safe=""))
 
         return uri
 
+    @classmethod
+    def get_schema_from_engine_params(
+        cls,
+        sqlalchemy_uri: URL,
+        connect_args: Dict[str, Any],
+    ) -> Optional[str]:
+        """
+        Return the configured schema.
+
+        A MySQL database is a SQLAlchemy schema.
+        """
+        return parse.unquote(sqlalchemy_uri.database)
+
     @classmethod
     def get_datatype(cls, type_code: Any) -> Optional[str]:
         if not cls.type_code_map:
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index cbe00ea58d..7b52eabfaf 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -146,6 +146,35 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
         ),
     }
 
+    @classmethod
+    def get_schema_from_engine_params(
+        cls,
+        sqlalchemy_uri: URL,
+        connect_args: Dict[str, Any],
+    ) -> Optional[str]:
+        """
+        Return the configured schema.
+
+        While Postgres doesn't support connecting directly to a given schema, it allows
+        users to specify a "search path" that is used to resolve non-qualified table
+        names; this can be specified in the database ``connect_args``.
+
+        One important detail is that the search path can be a comma separated list of
+        schemas. While this is supported by the SQLAlchemy dialect, it shouldn't be used
+        in Superset because it breaks schema-level permissions, since it's impossible
+        to determine the schema for a non-qualified table in a query. In cases like
+        that we raise an exception.
+        """
+        options = re.split(r"-c\s?", connect_args.get("options", ""))
+        for option in options:
+            if "=" not in option:
+                continue
+            key, value = option.strip().split("=", 1)
+            if key.strip() == "search_path":
+                return value.strip()
+
+        return None
+
     @classmethod
     def fetch_data(
         cls, cursor: Any, limit: Optional[int] = None
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index cda946ec4d..dd7bd88cdb 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -165,7 +165,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
     A base class that share common functions between Presto and Trino
     """
 
-    dynamic_schema = True
+    supports_dynamic_schema = True
 
     column_type_mappings = (
         (
@@ -315,6 +315,27 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
 
         return uri
 
+    @classmethod
+    def get_schema_from_engine_params(
+        cls,
+        sqlalchemy_uri: URL,
+        connect_args: Dict[str, Any],
+    ) -> Optional[str]:
+        """
+        Return the configured schema.
+
+        For Presto the SQLAlchemy URI looks like this:
+
+            presto://localhost:8080/hive[/default]
+
+        """
+        database = sqlalchemy_uri.database.strip("/")
+
+        if "/" not in database:
+            return None
+
+        return parse.unquote(database.split("/")[1])
+
     @classmethod
     def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
         """
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index 38addb6e35..ba15eea7fb 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -83,7 +83,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     default_driver = "snowflake"
     sqlalchemy_uri_placeholder = "snowflake://"
 
-    dynamic_schema = True
+    supports_dynamic_schema = True
 
     _time_grain_expressions = {
         None: "{col}",
@@ -147,6 +147,22 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
 
         return uri
 
+    @classmethod
+    def get_schema_from_engine_params(
+        cls,
+        sqlalchemy_uri: URL,
+        connect_args: Dict[str, Any],
+    ) -> Optional[str]:
+        """
+        Return the configured schema.
+        """
+        database = sqlalchemy_uri.database.strip("/")
+
+        if "/" not in database:
+            return None
+
+        return parse.unquote(database.split("/")[1])
+
     @classmethod
     def epoch_to_dttm(cls) -> str:
         return "DATEADD(S, {col}, '1970-01-01')"
diff --git a/superset/models/core.py b/superset/models/core.py
index 9c67a2efa6..edf06fb269 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -78,6 +78,7 @@ logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from superset.databases.ssh_tunnel.models import SSHTunnel
+    from superset.models.sql_lab import Query
 
 DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"]
 
@@ -483,6 +484,22 @@ class Database(
             with closing(engine.raw_connection()) as conn:
                 yield conn
 
+    def get_default_schema_for_query(self, query: "Query"):
+        """
+        Return the default schema for a given query.
+
+        This is used to determine if the user has access to a query that reads from table
+        names without a specific schema, eg:
+
+            SELECT * FROM `foo`
+
+        The schema of the `foo` table depends on the DB engine spec. Some DB engine specs
+        can change the default schema on a per-query basis; in other DB engine specs the
+        default schema is defined in the SQLAlchemy URI; and in others the default schema
+        might be determined by the database itself (like `public` for Postgres).
+        """
+        return self.db_engine_spec.get_default_schema_for_query(self, query)
+
     @property
     def quote_identifier(self) -> Callable[[str], str]:
         """Add quotes to potential identifiter expressions if needed"""
diff --git a/superset/security/manager.py b/superset/security/manager.py
index 9c154c6498..b197f126d1 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -1823,18 +1823,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
                 return
 
             if query:
-                # Some databases can change the default schema in which the query wil run,
-                # respecting the selection in SQL Lab. If that's the case, the query
-                # schema becomes the default one.
-                if database.db_engine_spec.dynamic_schema:
-                    default_schema = query.schema
-                # For other databases, the selected schema in SQL Lab is used only for
-                # table discovery and autocomplete. In this case we need to use the
-                # database default schema for tables that don't have an explicit schema.
-                else:
-                    with database.get_inspector_with_context() as inspector:
-                        default_schema = inspector.default_schema_name
-
+                default_schema = database.get_default_schema_for_query(query)
                 tables = {
                     Table(table_.table, table_.schema or default_schema)
                     for table_ in sql_parse.ParsedQuery(query.sql).tables
diff --git a/tests/unit_tests/security/manager_test.py b/tests/unit_tests/security/manager_test.py
index 6d0468c75c..1843e7261c 100644
--- a/tests/unit_tests/security/manager_test.py
+++ b/tests/unit_tests/security/manager_test.py
@@ -52,8 +52,7 @@ def test_raise_for_access_query_default_schema(
     SqlaTable.query_datasources_by_name.return_value = []
 
     database = mocker.MagicMock()
-    database.db_engine_spec.dynamic_schema = False
-    database.get_inspector_with_context().__enter__().default_schema_name = "public"
+    database.get_default_schema_for_query.return_value = "public"
     query = mocker.MagicMock()
     query.database = database
     query.sql = "SELECT * FROM ab_user"