You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2023/07/20 19:57:55 UTC
[superset] branch master updated: fix: `search_path` in RDS (#24739)
This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 7675e0db10 fix: `search_path` in RDS (#24739)
7675e0db10 is described below
commit 7675e0db10f42dbb76f908e9bc70906da204c98d
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Thu Jul 20 12:57:48 2023 -0700
fix: `search_path` in RDS (#24739)
---
superset/db_engine_specs/base.py | 33 +++++++--
superset/db_engine_specs/postgres.py | 88 +++++++++++++++--------
superset/models/core.py | 83 +++++++++++----------
tests/unit_tests/db_engine_specs/test_postgres.py | 48 +++++++++----
tests/unit_tests/models/core_test.py | 18 +++++
5 files changed, 185 insertions(+), 85 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 0d778de439..af24e54790 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1082,22 +1082,45 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
For some databases (like MySQL, Presto, Snowflake) this requires modifying the
SQLAlchemy URI before creating the connection. For others (like Postgres), it
- requires additional parameters in ``connect_args``.
+ requires additional parameters in ``connect_args`` or running pre-session
+ queries with ``set`` parameters.
- When a DB engine spec implements this method it should also have the attribute
- ``supports_dynamic_schema`` set to true, so that Superset knows in which schema a
- given query is running in order to enforce permissions (see #23385 and #23401).
+ When a DB engine spec implements this method or ``get_prequeries`` (see below) it
+ should also have the attribute ``supports_dynamic_schema`` set to true, so that
+ Superset knows in which schema a given query is running in order to enforce
+ permissions (see #23385 and #23401).
Currently, changing the catalog is not supported. The method accepts a catalog so
that when catalog support is added to Superset the interface remains the same.
This is important because DB engine specs can be installed from 3rd party
- packages.
+ packages, so we want to keep these methods as stable as possible.
"""
return uri, {
**connect_args,
**cls.enforce_uri_query_params.get(uri.get_driver_name(), {}),
}
+ @classmethod
+ def get_prequeries(
+ cls,
+ catalog: str | None = None, # pylint: disable=unused-argument
+ schema: str | None = None, # pylint: disable=unused-argument
+ ) -> list[str]:
+ """
+ Return the pre-session queries for the given catalog/schema (none by default).
+
+ These are currently used as an alternative to ``adjust_engine_params`` for
+ databases where the selected schema cannot be specified in the SQLAlchemy URI or
+ connection arguments.
+
+ For example, in order to specify a default schema in RDS we need to run a query
+ at the beginning of the session:
+
+ sql> set search_path = my_schema;
+
+ """
+ return []
+
@classmethod
def patch(cls) -> None:
"""
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index cdd71fdfcc..642f84f58c 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -14,13 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+from __future__ import annotations
+
import json
import logging
import re
from datetime import datetime
from re import Pattern
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
+import sqlparse
from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql.base import PGInspector
@@ -30,8 +34,8 @@ from sqlalchemy.types import Date, DateTime, String
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
-from superset.errors import SupersetErrorType
-from superset.exceptions import SupersetException
+from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.exceptions import SupersetException, SupersetSecurityException
from superset.models.sql_lab import Query
from superset.utils import core as utils
from superset.utils.core import GenericDataType
@@ -169,9 +173,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
}
@classmethod
- def fetch_data(
- cls, cursor: Any, limit: Optional[int] = None
- ) -> list[tuple[Any, ...]]:
+ def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]:
if not cursor.description:
return []
return super().fetch_data(cursor, limit)
@@ -224,7 +226,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
cls,
sqlalchemy_uri: URL,
connect_args: dict[str, Any],
- ) -> Optional[str]:
+ ) -> str | None:
"""
Return the configured schema.
@@ -237,6 +239,9 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
in Superset because it breaks schema-level permissions, since it's impossible
to determine the schema for a non-qualified table in a query. In cases like
that we raise an exception.
+
+ Note that because the DB engine supports dynamic schema this method is never
+ called. It's left here as an implementation reference.
"""
options = parse_options(connect_args)
if search_path := options.get("search_path"):
@@ -252,23 +257,50 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
return None
@classmethod
- def adjust_engine_params(
+ def get_default_schema_for_query(
cls,
- uri: URL,
- connect_args: dict[str, Any],
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
- ) -> tuple[URL, dict[str, Any]]:
- if not schema:
- return uri, connect_args
+ database: Database,
+ query: Query,
+ ) -> str | None:
+ """
+ Return the default schema for a given query.
- options = parse_options(connect_args)
- options["search_path"] = schema
- connect_args["options"] = " ".join(
- f"-c{key}={value}" for key, value in options.items()
- )
+ Delegate to the parent, but raise ``SupersetSecurityException`` if the query
+ changes ``search_path`` (``SET [SESSION|LOCAL] search_path =/TO``, ``set_config``).
+ """
+ sql = sqlparse.format(query.sql, strip_comments=True)
+ if re.search(r"set\s+((session|local)\s+)?search_path\s*(=|to\b)|set_config\s*\([^,]*search_path", sql, re.IGNORECASE):
+ raise SupersetSecurityException(
+ SupersetError(
+ error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
+ message=__(
+ "Users are not allowed to set a search path for security reasons."
+ ),
+ level=ErrorLevel.ERROR,
+ )
+ )
+
+ return super().get_default_schema_for_query(database, query)
- return uri, connect_args
+ @classmethod
+ def get_prequeries(
+ cls,
+ catalog: str | None = None,
+ schema: str | None = None,
+ ) -> list[str]:
+ """
+ Set the search path to the specified schema.
+
+ This lets SQL Lab run queries in the schema selected in the dropdown and, more
+ importantly, check access to tables referenced without an explicit schema.
+
+ The schema is emitted as a quoted identifier with embedded double quotes doubled,
+ so a crafted schema name cannot close the identifier and inject extra SQL.
+ """
+ if not schema:
+ return []
+ quoted_schema = schema.replace('"', '""')
+ return [f'set search_path = "{quoted_schema}"']
@classmethod
def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
@@ -298,7 +330,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
@classmethod
def get_catalog_names(
cls,
- database: "Database",
+ database: Database,
inspector: Inspector,
) -> list[str]:
"""
@@ -318,7 +350,7 @@ WHERE datistemplate = false;
@classmethod
def get_table_names(
- cls, database: "Database", inspector: PGInspector, schema: Optional[str]
+ cls, database: Database, inspector: PGInspector, schema: str | None
) -> set[str]:
"""Need to consider foreign tables for PostgreSQL"""
return set(inspector.get_table_names(schema)) | set(
@@ -327,8 +359,8 @@ WHERE datistemplate = false;
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, Date):
@@ -339,7 +371,7 @@ WHERE datistemplate = false;
return None
@staticmethod
- def get_extra_params(database: "Database") -> dict[str, Any]:
+ def get_extra_params(database: Database) -> dict[str, Any]:
"""
For Postgres, the path to a SSL certificate is placed in `connect_args`.
@@ -363,7 +395,7 @@ WHERE datistemplate = false;
return extra
@classmethod
- def get_datatype(cls, type_code: Any) -> Optional[str]:
+ def get_datatype(cls, type_code: Any) -> str | None:
# pylint: disable=import-outside-toplevel
from psycopg2.extensions import binary_types, string_types
@@ -374,7 +406,7 @@ WHERE datistemplate = false;
return None
@classmethod
- def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
+ def get_cancel_query_id(cls, cursor: Any, query: Query) -> str | None:
"""
Get Postgres PID that will be used to cancel all other running
queries in the same session.
diff --git a/superset/models/core.py b/superset/models/core.py
index 4ff56145e1..a8fe8b5411 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -16,6 +16,9 @@
# under the License.
# pylint: disable=line-too-long,too-many-lines
"""A collection of ORM sqlalchemy models for Superset"""
+
+from __future__ import annotations
+
import builtins
import enum
import json
@@ -26,7 +29,7 @@ from contextlib import closing, contextmanager, nullcontext
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING
import numpy
import pandas as pd
@@ -270,7 +273,7 @@ class Database(
return self.url_object.get_driver_name()
@property
- def masked_encrypted_extra(self) -> Optional[str]:
+ def masked_encrypted_extra(self) -> str | None:
return self.db_engine_spec.mask_encrypted_extra(self.encrypted_extra)
@property
@@ -315,7 +318,7 @@ class Database(
return "schema_cache_timeout" in self.metadata_cache_timeout
@property
- def schema_cache_timeout(self) -> Optional[int]:
+ def schema_cache_timeout(self) -> int | None:
return self.metadata_cache_timeout.get("schema_cache_timeout")
@property
@@ -323,7 +326,7 @@ class Database(
return "table_cache_timeout" in self.metadata_cache_timeout
@property
- def table_cache_timeout(self) -> Optional[int]:
+ def table_cache_timeout(self) -> int | None:
return self.metadata_cache_timeout.get("table_cache_timeout")
@property
@@ -364,7 +367,7 @@ class Database(
conn = conn.set(password=PASSWORD_MASK if conn.password else None)
self.sqlalchemy_uri = str(conn) # hides the password
- def get_effective_user(self, object_url: URL) -> Optional[str]:
+ def get_effective_user(self, object_url: URL) -> str | None:
"""
Get the effective user, especially during impersonation.
@@ -383,10 +386,10 @@ class Database(
@contextmanager
def get_sqla_engine_with_context(
self,
- schema: Optional[str] = None,
+ schema: str | None = None,
nullpool: bool = True,
- source: Optional[utils.QuerySource] = None,
- override_ssh_tunnel: Optional["SSHTunnel"] = None,
+ source: utils.QuerySource | None = None,
+ override_ssh_tunnel: SSHTunnel | None = None,
) -> Engine:
from superset.daos.database import ( # pylint: disable=import-outside-toplevel
DatabaseDAO,
@@ -425,10 +428,10 @@ class Database(
def _get_sqla_engine(
self,
- schema: Optional[str] = None,
+ schema: str | None = None,
nullpool: bool = True,
- source: Optional[utils.QuerySource] = None,
- sqlalchemy_uri: Optional[str] = None,
+ source: utils.QuerySource | None = None,
+ sqlalchemy_uri: str | None = None,
) -> Engine:
sqlalchemy_url = make_url_safe(
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
@@ -513,17 +516,23 @@ class Database(
@contextmanager
def get_raw_connection(
self,
- schema: Optional[str] = None,
+ schema: str | None = None,
nullpool: bool = True,
- source: Optional[utils.QuerySource] = None,
+ source: utils.QuerySource | None = None,
) -> Connection:
with self.get_sqla_engine_with_context(
schema=schema, nullpool=nullpool, source=source
) as engine:
with closing(engine.raw_connection()) as conn:
+ # pre-session queries set the selected schema (and, in the future, catalog);
+ # each runs on its own cursor, closed right away so cursors are not leaked
+ for prequery in self.db_engine_spec.get_prequeries(schema=schema):
+ with closing(conn.cursor()) as cursor:
+ cursor.execute(prequery)
+
yield conn
- def get_default_schema_for_query(self, query: "Query") -> Optional[str]:
+ def get_default_schema_for_query(self, query: Query) -> str | None:
"""
Return the default schema for a given query.
@@ -550,8 +559,8 @@ class Database(
def get_df( # pylint: disable=too-many-locals
self,
sql: str,
- schema: Optional[str] = None,
- mutator: Optional[Callable[[pd.DataFrame], None]] = None,
+ schema: str | None = None,
+ mutator: Callable[[pd.DataFrame], None] | None = None,
) -> pd.DataFrame:
sqls = self.db_engine_spec.parse_sql(sql)
engine = self._get_sqla_engine(schema)
@@ -614,7 +623,7 @@ class Database(
return df
- def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str:
+ def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str:
engine = self._get_sqla_engine(schema=schema)
sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
@@ -628,12 +637,12 @@ class Database(
def select_star( # pylint: disable=too-many-arguments
self,
table_name: str,
- schema: Optional[str] = None,
+ schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = False,
- cols: Optional[list[ResultSetColumnType]] = None,
+ cols: list[ResultSetColumnType] | None = None,
) -> str:
"""Generates a ``select *`` statement in the proper dialect"""
eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
@@ -672,7 +681,7 @@ class Database(
self,
schema: str,
cache: bool = False,
- cache_timeout: Optional[int] = None,
+ cache_timeout: int | None = None,
force: bool = False,
) -> set[tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.
@@ -708,7 +717,7 @@ class Database(
self,
schema: str,
cache: bool = False,
- cache_timeout: Optional[int] = None,
+ cache_timeout: int | None = None,
force: bool = False,
) -> set[tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.
@@ -737,7 +746,7 @@ class Database(
@contextmanager
def get_inspector_with_context(
- self, ssh_tunnel: Optional["SSHTunnel"] = None
+ self, ssh_tunnel: SSHTunnel | None = None
) -> Inspector:
with self.get_sqla_engine_with_context(
override_ssh_tunnel=ssh_tunnel
@@ -751,9 +760,9 @@ class Database(
def get_all_schema_names( # pylint: disable=unused-argument
self,
cache: bool = False,
- cache_timeout: Optional[int] = None,
+ cache_timeout: int | None = None,
force: bool = False,
- ssh_tunnel: Optional["SSHTunnel"] = None,
+ ssh_tunnel: SSHTunnel | None = None,
) -> list[str]:
"""Parameters need to be passed as keyword arguments.
@@ -818,7 +827,7 @@ class Database(
def update_params_from_encrypted_extra(self, params: dict[str, Any]) -> None:
self.db_engine_spec.update_params_from_encrypted_extra(self, params)
- def get_table(self, table_name: str, schema: Optional[str] = None) -> Table:
+ def get_table(self, table_name: str, schema: str | None = None) -> Table:
extra = self.get_extra()
meta = MetaData(**extra.get("metadata_params", {}))
with self.get_sqla_engine_with_context() as engine:
@@ -831,13 +840,13 @@ class Database(
)
def get_table_comment(
- self, table_name: str, schema: Optional[str] = None
- ) -> Optional[str]:
+ self, table_name: str, schema: str | None = None
+ ) -> str | None:
with self.get_inspector_with_context() as inspector:
return self.db_engine_spec.get_table_comment(inspector, table_name, schema)
def get_columns(
- self, table_name: str, schema: Optional[str] = None
+ self, table_name: str, schema: str | None = None
) -> list[ResultSetColumnType]:
with self.get_inspector_with_context() as inspector:
return self.db_engine_spec.get_columns(inspector, table_name, schema)
@@ -845,19 +854,19 @@ class Database(
def get_metrics(
self,
table_name: str,
- schema: Optional[str] = None,
+ schema: str | None = None,
) -> list[MetricType]:
with self.get_inspector_with_context() as inspector:
return self.db_engine_spec.get_metrics(self, inspector, table_name, schema)
def get_indexes(
- self, table_name: str, schema: Optional[str] = None
+ self, table_name: str, schema: str | None = None
) -> list[dict[str, Any]]:
with self.get_inspector_with_context() as inspector:
return self.db_engine_spec.get_indexes(self, inspector, table_name, schema)
def get_pk_constraint(
- self, table_name: str, schema: Optional[str] = None
+ self, table_name: str, schema: str | None = None
) -> dict[str, Any]:
with self.get_inspector_with_context() as inspector:
pk_constraint = inspector.get_pk_constraint(table_name, schema) or {}
@@ -871,7 +880,7 @@ class Database(
return {key: _convert(value) for key, value in pk_constraint.items()}
def get_foreign_keys(
- self, table_name: str, schema: Optional[str] = None
+ self, table_name: str, schema: str | None = None
) -> list[dict[str, Any]]:
with self.get_inspector_with_context() as inspector:
return inspector.get_foreign_keys(table_name, schema)
@@ -926,7 +935,7 @@ class Database(
with self.get_sqla_engine_with_context() as engine:
return engine.has_table(table.table_name, table.schema or None)
- def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool:
+ def has_table_by_name(self, table_name: str, schema: str | None = None) -> bool:
with self.get_sqla_engine_with_context() as engine:
return engine.has_table(table_name, schema)
@@ -936,7 +945,7 @@ class Database(
conn: Connection,
dialect: Dialect,
view_name: str,
- schema: Optional[str] = None,
+ schema: str | None = None,
) -> bool:
view_names: list[str] = []
try:
@@ -945,11 +954,11 @@ class Database(
logger.warning("Has view failed", exc_info=True)
return view_name in view_names
- def has_view(self, view_name: str, schema: Optional[str] = None) -> bool:
+ def has_view(self, view_name: str, schema: str | None = None) -> bool:
engine = self._get_sqla_engine()
return engine.run_callable(self._has_view, engine.dialect, view_name, schema)
- def has_view_by_name(self, view_name: str, schema: Optional[str] = None) -> bool:
+ def has_view_by_name(self, view_name: str, schema: str | None = None) -> bool:
return self.has_view(view_name=view_name, schema=schema)
def get_dialect(self) -> Dialect:
@@ -957,7 +966,7 @@ class Database(
return sqla_url.get_dialect()()
def make_sqla_column_compatible(
- self, sqla_col: ColumnElement, label: Optional[str] = None
+ self, sqla_col: ColumnElement, label: str | None = None
) -> ColumnElement:
"""Takes a sqlalchemy column object and adds label info if supported by engine.
:param sqla_col: sqlalchemy column instance
diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py
index 145d398898..59d1829f14 100644
--- a/tests/unit_tests/db_engine_specs/test_postgres.py
+++ b/tests/unit_tests/db_engine_specs/test_postgres.py
@@ -19,10 +19,12 @@ from datetime import datetime
from typing import Any, Optional
import pytest
+from pytest_mock import MockFixture
from sqlalchemy import types
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.engine.url import make_url
+from superset.exceptions import SupersetSecurityException
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
@@ -133,25 +135,41 @@ def test_get_schema_from_engine_params() -> None:
)
-def test_adjust_engine_params() -> None:
+def test_get_prequeries() -> None:
"""
- Test the ``adjust_engine_params`` method.
+ Test that ``get_prequeries`` emits ``set search_path`` only when given a schema.
"""
from superset.db_engine_specs.postgres import PostgresEngineSpec
- uri = make_url("postgres://user:password@host/catalog")
+ assert PostgresEngineSpec.get_prequeries() == []
+ assert PostgresEngineSpec.get_prequeries(schema="test") == [
+ 'set search_path = "test"'
+ ]
- assert PostgresEngineSpec.adjust_engine_params(uri, {}, None, "secret") == (
- uri,
- {"options": "-csearch_path=secret"},
- )
- assert PostgresEngineSpec.adjust_engine_params(
- uri,
- {"foo": "bar", "options": "-csearch_path=default -c debug=1"},
- None,
- "secret",
- ) == (
- uri,
- {"foo": "bar", "options": "-csearch_path=secret -cdebug=1"},
+def test_get_default_schema_for_query(mocker: MockFixture) -> None:
+ """
+ Test ``get_default_schema_for_query`` rejects ``set search_path``, even across comments.
+ """
+ from superset.db_engine_specs.postgres import PostgresEngineSpec
+
+ database = mocker.MagicMock()
+ query = mocker.MagicMock()
+
+ query.sql = "SELECT * FROM some_table"
+ query.schema = "foo"
+ assert PostgresEngineSpec.get_default_schema_for_query(database, query) == "foo"
+
+ query.sql = """
+set
+-- this is a tricky comment
+search_path -- another one
+= bar;
+SELECT * FROM some_table;
+ """
+ with pytest.raises(SupersetSecurityException) as excinfo:
+ PostgresEngineSpec.get_default_schema_for_query(database, query)
+ assert (
+ str(excinfo.value)
+ == "Users are not allowed to set a search path for security reasons."
)
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index 267b7c024a..5d6c1fcbcc 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -212,3 +212,21 @@ def test_dttm_sql_literal(
def test_table_column_database() -> None:
database = Database(database_name="db")
assert TableColumn(database=database).database is database
+
+
+def test_get_prequeries(mocker: MockFixture) -> None:
+ """
+ Test that ``get_raw_connection`` executes every engine-spec pre-query in order.
+ """
+ mocker.patch.object(
+ Database,
+ "get_sqla_engine_with_context",
+ )
+ db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
+ db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"]
+
+ database = Database(database_name="db")
+ with database.get_raw_connection() as conn:
+ conn.cursor().execute.assert_has_calls(
+ [mocker.call("set a=1"), mocker.call("set b=2")]
+ )