Posted to commits@superset.apache.org by am...@apache.org on 2021/11/22 08:02:40 UTC
[superset] branch master updated: refactor(QueryContext): separate all processing concerns into QueryContextProcessor (#17497)
This is an automated email from the ASF dual-hosted git repository.
amitmiran pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 66d7569 refactor(QueryContext): separate all processing concerns into QueryContextProcessor (#17497)
66d7569 is described below
commit 66d756955b961ed8dd379783f453d1f2531fad9e
Author: ofekisr <35...@users.noreply.github.com>
AuthorDate: Mon Nov 22 10:01:01 2021 +0200
refactor(QueryContext): separate all processing concerns into QueryContextProcessor (#17497)
---
superset/common/query_context.py | 414 ++-------------------
...query_context.py => query_context_processor.py} | 390 ++++++++++---------
tests/integration_tests/charts/schema_tests.py | 3 +-
3 files changed, 213 insertions(+), 594 deletions(-)
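For readers skimming the diff below: the commit leaves QueryContext's public surface in place and moves the heavy lifting into a new QueryContextProcessor, which the context instantiates and forwards to. A minimal sketch of that delegation pattern follows; class and method names are taken from the diff, but the bodies are simplified illustrative stand-ins, not Superset's actual implementation.

    from typing import Any, Dict


    class QueryContextProcessor:
        """Owns the processing concerns: caching, normalization, payload building."""

        def __init__(self, query_context: "QueryContext") -> None:
            self._query_context = query_context

        def get_payload(self, cache_query_context: bool, force_cached: bool) -> Dict[str, Any]:
            # the real method builds per-query payloads and handles caching
            return {"queries": []}


    class QueryContext:
        """Thin holder for datasource/queries; every processing call is forwarded."""

        def __init__(self, **kwargs: Any) -> None:
            self._processor = QueryContextProcessor(self)

        def get_payload(self, cache_query_context: bool = False, force_cached: bool = False) -> Dict[str, Any]:
            return self._processor.get_payload(cache_query_context, force_cached)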
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index b1677c0..5ddd174 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -16,56 +16,24 @@
# under the License.
from __future__ import annotations
-import copy
import logging
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
-import numpy as np
import pandas as pd
-from flask_babel import _
-from pandas import DateOffset
-from typing_extensions import TypedDict
-from superset import app, is_feature_enabled
-from superset.annotation_layers.dao import AnnotationLayerDAO
-from superset.charts.dao import ChartDAO
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
-from superset.common.db_query_status import QueryStatus
-from superset.common.query_actions import get_query_results
-from superset.common.query_object import QueryObject
-from superset.common.utils import dataframe_utils as df_utils
-from superset.common.utils.query_cache_manager import QueryCacheManager
-from superset.constants import CacheRegion
-from superset.exceptions import QueryObjectValidationError, SupersetException
-from superset.extensions import cache_manager, security_manager
-from superset.models.helpers import QueryResult
-from superset.utils import csv
-from superset.utils.cache import generate_cache_key, set_and_log_cache
-from superset.utils.core import (
- DTTM_ALIAS,
- error_msg_from_exception,
- get_column_names_from_columns,
- get_column_names_from_metrics,
- get_metric_names,
- normalize_dttm_col,
- TIME_COMPARISION,
+from superset.common.query_context_processor import (
+ CachedTimeOffset,
+ QueryContextProcessor,
)
-from superset.utils.date_parser import get_past_or_future, normalize_time_delta
-from superset.views.utils import get_viz
+from superset.common.query_object import QueryObject
if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
- from superset.stats_logger import BaseStatsLogger
-
-config = app.config
-stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
-logger = logging.getLogger(__name__)
+ from superset.models.helpers import QueryResult
-class CachedTimeOffset(TypedDict):
- df: pd.DataFrame
- queries: List[str]
- cache_keys: List[Optional[str]]
+logger = logging.getLogger(__name__)
class QueryContext:
@@ -86,6 +54,8 @@ class QueryContext:
cache_values: Dict[str, Any]
+ _processor: QueryContextProcessor
+
# TODO: Type datasource and query_object dictionary with TypedDict when it becomes
# a vanilla python type https://github.com/python/mypy/issues/5288
def __init__(
@@ -106,379 +76,41 @@ class QueryContext:
self.force = force
self.custom_cache_timeout = custom_cache_timeout
self.cache_values = cache_values
-
- def processing_time_offsets( # pylint: disable=too-many-locals
- self, df: pd.DataFrame, query_object: QueryObject,
- ) -> CachedTimeOffset:
- # ensure query_object is immutable
- query_object_clone = copy.copy(query_object)
- queries: List[str] = []
- cache_keys: List[Optional[str]] = []
- rv_dfs: List[pd.DataFrame] = [df]
-
- time_offsets = query_object.time_offsets
- outer_from_dttm = query_object.from_dttm
- outer_to_dttm = query_object.to_dttm
- for offset in time_offsets:
- try:
- query_object_clone.from_dttm = get_past_or_future(
- offset, outer_from_dttm,
- )
- query_object_clone.to_dttm = get_past_or_future(offset, outer_to_dttm)
- except ValueError as ex:
- raise QueryObjectValidationError(str(ex)) from ex
- # make sure subquery use main query where clause
- query_object_clone.inner_from_dttm = outer_from_dttm
- query_object_clone.inner_to_dttm = outer_to_dttm
- query_object_clone.time_offsets = []
- query_object_clone.post_processing = []
-
- if not query_object.from_dttm or not query_object.to_dttm:
- raise QueryObjectValidationError(
- _(
- "An enclosed time range (both start and end) must be specified "
- "when using a Time Comparison."
- )
- )
- # `offset` is added to the hash function
- cache_key = self.query_cache_key(query_object_clone, time_offset=offset)
- cache = QueryCacheManager.get(cache_key, CacheRegion.DATA, self.force)
- # whether hit on the cache
- if cache.is_loaded:
- rv_dfs.append(cache.df)
- queries.append(cache.query)
- cache_keys.append(cache_key)
- continue
-
- query_object_clone_dct = query_object_clone.to_dict()
- # rename metrics: SUM(value) => SUM(value) 1 year ago
- metrics_mapping = {
- metric: TIME_COMPARISION.join([metric, offset])
- for metric in get_metric_names(
- query_object_clone_dct.get("metrics", [])
- )
- }
- join_keys = [col for col in df.columns if col not in metrics_mapping.keys()]
-
- result = self.datasource.query(query_object_clone_dct)
- queries.append(result.query)
- cache_keys.append(None)
-
- offset_metrics_df = result.df
- if offset_metrics_df.empty:
- offset_metrics_df = pd.DataFrame(
- {
- col: [np.NaN]
- for col in join_keys + list(metrics_mapping.values())
- }
- )
- else:
- # 1. normalize df, set dttm column
- offset_metrics_df = self.normalize_df(
- offset_metrics_df, query_object_clone
- )
-
- # 2. rename extra query columns
- offset_metrics_df = offset_metrics_df.rename(columns=metrics_mapping)
-
- # 3. set time offset for dttm column
- offset_metrics_df[DTTM_ALIAS] = offset_metrics_df[
- DTTM_ALIAS
- ] - DateOffset(**normalize_time_delta(offset))
-
- # df left join `offset_metrics_df`
- offset_df = df_utils.left_join_df(
- left_df=df, right_df=offset_metrics_df, join_keys=join_keys,
- )
- offset_slice = offset_df[metrics_mapping.values()]
-
- # set offset_slice to cache and stack.
- value = {
- "df": offset_slice,
- "query": result.query,
- }
- cache.set(
- key=cache_key,
- value=value,
- timeout=self.cache_timeout,
- datasource_uid=self.datasource.uid,
- region=CacheRegion.DATA,
- )
- rv_dfs.append(offset_slice)
-
- rv_df = pd.concat(rv_dfs, axis=1, copy=False) if time_offsets else df
- return CachedTimeOffset(df=rv_df, queries=queries, cache_keys=cache_keys)
-
- def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame:
- timestamp_format = None
- if self.datasource.type == "table":
- dttm_col = self.datasource.get_column(query_object.granularity)
- if dttm_col:
- timestamp_format = dttm_col.python_date_format
-
- normalize_dttm_col(
- df=df,
- timestamp_format=timestamp_format,
- offset=self.datasource.offset,
- time_shift=query_object.time_shift,
- )
-
- if self.enforce_numerical_metrics:
- df_utils.df_metrics_to_num(df, query_object)
-
- df.replace([np.inf, -np.inf], np.nan, inplace=True)
-
- return df
-
- def get_query_result(self, query_object: QueryObject) -> QueryResult:
- """Returns a pandas dataframe based on the query object"""
-
- # Here, we assume that all the queries will use the same datasource, which is
- # a valid assumption for current setting. In the long term, we may
- # support multiple queries from different data sources.
-
- # The datasource here can be different backend but the interface is common
- result = self.datasource.query(query_object.to_dict())
- query = result.query + ";\n\n"
-
- df = result.df
- # Transform the timestamp we received from database to pandas supported
- # datetime format. If no python_date_format is specified, the pattern will
- # be considered as the default ISO date format
- # If the datetime format is unix, the parse will use the corresponding
- # parsing logic
- if not df.empty:
- df = self.normalize_df(df, query_object)
-
- if query_object.time_offsets:
- time_offsets = self.processing_time_offsets(df, query_object)
- df = time_offsets["df"]
- queries = time_offsets["queries"]
-
- query += ";\n\n".join(queries)
- query += ";\n\n"
-
- df = query_object.exec_post_processing(df)
-
- result.df = df
- result.query = query
- return result
+ self._processor = QueryContextProcessor(self)
def get_data(self, df: pd.DataFrame,) -> Union[str, List[Dict[str, Any]]]:
- if self.result_format == ChartDataResultFormat.CSV:
- include_index = not isinstance(df.index, pd.RangeIndex)
- result = csv.df_to_escaped_csv(
- df, index=include_index, **config["CSV_EXPORT"]
- )
- return result or ""
-
- return df.to_dict(orient="records")
+ return self._processor.get_data(df)
def get_payload(
self, cache_query_context: Optional[bool] = False, force_cached: bool = False,
) -> Dict[str, Any]:
"""Returns the query results with both metadata and data"""
+ return self._processor.get_payload(cache_query_context, force_cached)
- # Get all the payloads from the QueryObjects
- query_results = [
- get_query_results(
- query_obj.result_type or self.result_type, self, query_obj, force_cached
- )
- for query_obj in self.queries
- ]
- return_value = {"queries": query_results}
-
- if cache_query_context:
- cache_key = self.cache_key()
- set_and_log_cache(
- cache_manager.cache,
- cache_key,
- {"data": self.cache_values},
- self.cache_timeout,
- )
- return_value["cache_key"] = cache_key # type: ignore
-
- return return_value
-
- @property
- def cache_timeout(self) -> int:
+ def get_cache_timeout(self) -> Optional[int]:
if self.custom_cache_timeout is not None:
return self.custom_cache_timeout
if self.datasource.cache_timeout is not None:
return self.datasource.cache_timeout
- if (
- hasattr(self.datasource, "database")
- and self.datasource.database.cache_timeout
- ) is not None:
+ if hasattr(self.datasource, "database"):
return self.datasource.database.cache_timeout
- return config["CACHE_DEFAULT_TIMEOUT"]
-
- def cache_key(self, **extra: Any) -> str:
- """
- The QueryContext cache key is made out of the key/values from
- self.cached_values, plus any other key/values in `extra`. It includes only data
- required to rehydrate a QueryContext object.
- """
- key_prefix = "qc-"
- cache_dict = self.cache_values.copy()
- cache_dict.update(extra)
-
- return generate_cache_key(cache_dict, key_prefix)
+ return None
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
- """
- Returns a QueryObject cache key for objects in self.queries
- """
- extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj.to_dict())
-
- cache_key = (
- query_obj.cache_key(
- datasource=self.datasource.uid,
- extra_cache_keys=extra_cache_keys,
- rls=security_manager.get_rls_ids(self.datasource)
- if is_feature_enabled("ROW_LEVEL_SECURITY")
- and self.datasource.is_rls_supported
- else [],
- changed_on=self.datasource.changed_on,
- **kwargs,
- )
- if query_obj
- else None
- )
- return cache_key
-
- @staticmethod
- def get_native_annotation_data(query_obj: QueryObject) -> Dict[str, Any]:
- annotation_data = {}
- annotation_layers = [
- layer
- for layer in query_obj.annotation_layers
- if layer["sourceType"] == "NATIVE"
- ]
- layer_ids = [layer["value"] for layer in annotation_layers]
- layer_objects = {
- layer_object.id: layer_object
- for layer_object in AnnotationLayerDAO.find_by_ids(layer_ids)
- }
-
- # annotations
- for layer in annotation_layers:
- layer_id = layer["value"]
- layer_name = layer["name"]
- columns = [
- "start_dttm",
- "end_dttm",
- "short_descr",
- "long_descr",
- "json_metadata",
- ]
- layer_object = layer_objects[layer_id]
- records = [
- {column: getattr(annotation, column) for column in columns}
- for annotation in layer_object.annotation
- ]
- result = {"columns": columns, "records": records}
- annotation_data[layer_name] = result
- return annotation_data
-
- @staticmethod
- def get_viz_annotation_data(
- annotation_layer: Dict[str, Any], force: bool
- ) -> Dict[str, Any]:
- chart = ChartDAO.find_by_id(annotation_layer["value"])
- form_data = chart.form_data.copy()
- if not chart:
- raise QueryObjectValidationError(_("The chart does not exist"))
- try:
- viz_obj = get_viz(
- datasource_type=chart.datasource.type,
- datasource_id=chart.datasource.id,
- form_data=form_data,
- force=force,
- )
- payload = viz_obj.get_payload()
- return payload["data"]
- except SupersetException as ex:
- raise QueryObjectValidationError(error_msg_from_exception(ex)) from ex
-
- def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]:
- """
-
- :param query_obj:
- :return:
- """
- annotation_data: Dict[str, Any] = self.get_native_annotation_data(query_obj)
- for annotation_layer in [
- layer
- for layer in query_obj.annotation_layers
- if layer["sourceType"] in ("line", "table")
- ]:
- name = annotation_layer["name"]
- annotation_data[name] = self.get_viz_annotation_data(
- annotation_layer, self.force
- )
- return annotation_data
+ return self._processor.query_cache_key(query_obj, **kwargs)
def get_df_payload(
self, query_obj: QueryObject, force_cached: Optional[bool] = False,
) -> Dict[str, Any]:
- """Handles caching around the df payload retrieval"""
- cache_key = self.query_cache_key(query_obj)
- cache = QueryCacheManager.get(
- cache_key, CacheRegion.DATA, self.force, force_cached,
- )
+ return self._processor.get_df_payload(query_obj, force_cached)
- if query_obj and cache_key and not cache.is_loaded:
- try:
- invalid_columns = [
- col
- for col in get_column_names_from_columns(query_obj.columns)
- + get_column_names_from_metrics(query_obj.metrics or [])
- if col not in self.datasource.column_names and col != DTTM_ALIAS
- ]
- if invalid_columns:
- raise QueryObjectValidationError(
- _(
- "Columns missing in datasource: %(invalid_columns)s",
- invalid_columns=invalid_columns,
- )
- )
- query_result = self.get_query_result(query_obj)
- annotation_data = self.get_annotation_data(query_obj)
- cache.set_query_result(
- key=cache_key,
- query_result=query_result,
- annotation_data=annotation_data,
- force_query=self.force,
- timeout=self.cache_timeout,
- datasource_uid=self.datasource.uid,
- region=CacheRegion.DATA,
- )
- except QueryObjectValidationError as ex:
- cache.error_message = str(ex)
- cache.status = QueryStatus.FAILED
+ def get_query_result(self, query_object: QueryObject) -> QueryResult:
+ return self._processor.get_query_result(query_object)
- return {
- "cache_key": cache_key,
- "cached_dttm": cache.cache_dttm,
- "cache_timeout": self.cache_timeout,
- "df": cache.df,
- "applied_template_filters": cache.applied_template_filters,
- "annotation_data": cache.annotation_data,
- "error": cache.error_message,
- "is_cached": cache.is_cached,
- "query": cache.query,
- "status": cache.status,
- "stacktrace": cache.stacktrace,
- "rowcount": len(cache.df.index),
- }
+ def processing_time_offsets(
+ self, df: pd.DataFrame, query_object: QueryObject,
+ ) -> CachedTimeOffset:
+ return self._processor.processing_time_offsets(df, query_object)
def raise_for_access(self) -> None:
- """
- Raise an exception if the user cannot access the resource.
-
- :raises SupersetSecurityException: If the user cannot access the resource
- """
- for query in self.queries:
- query.validate()
- security_manager.raise_for_access(query_context=self)
+ self._processor.raise_for_access()
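One behavioral detail worth noting before the processor file itself: the old cache_timeout property is replaced by a two-level lookup. QueryContext.get_cache_timeout() now returns the first explicitly configured value (custom, then datasource, then database) or None, and QueryContextProcessor.get_cache_timeout() applies the application default. A hedged sketch of that fallback chain, with the config lookup simplified to a constant and stand-in classes in place of the real datasource models:

    from typing import Optional

    CACHE_DEFAULT_TIMEOUT = 300  # stand-in for config["CACHE_DEFAULT_TIMEOUT"]; value assumed for illustration


    class Database:
        """Illustrative stand-in for the datasource's database object."""
        def __init__(self, cache_timeout: Optional[int] = None) -> None:
            self.cache_timeout = cache_timeout


    class Datasource:
        """Illustrative stand-in for BaseDatasource."""
        def __init__(self, cache_timeout: Optional[int] = None,
                     database: Optional[Database] = None) -> None:
            self.cache_timeout = cache_timeout
            if database is not None:
                self.database = database


    def context_cache_timeout(custom: Optional[int], datasource: Datasource) -> Optional[int]:
        # mirrors QueryContext.get_cache_timeout(): custom -> datasource -> database -> None
        if custom is not None:
            return custom
        if datasource.cache_timeout is not None:
            return datasource.cache_timeout
        if hasattr(datasource, "database"):
            return datasource.database.cache_timeout
        return None


    def processor_cache_timeout(custom: Optional[int], datasource: Datasource) -> int:
        # mirrors QueryContextProcessor.get_cache_timeout(): a falsy result
        # falls back to the configured default
        return context_cache_timeout(custom, datasource) or CACHE_DEFAULT_TIMEOUT


    assert processor_cache_timeout(None, Datasource()) == CACHE_DEFAULT_TIMEOUT
    assert processor_cache_timeout(60, Datasource()) == 60
    assert processor_cache_timeout(None, Datasource(database=Database(120))) == 120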
diff --git a/superset/common/query_context.py b/superset/common/query_context_processor.py
similarity index 81%
copy from superset/common/query_context.py
copy to superset/common/query_context_processor.py
index b1677c0..d1328dd 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context_processor.py
@@ -29,12 +29,12 @@ from typing_extensions import TypedDict
from superset import app, is_feature_enabled
from superset.annotation_layers.dao import AnnotationLayerDAO
from superset.charts.dao import ChartDAO
-from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
+from superset.common.chart_data import ChartDataResultFormat
from superset.common.db_query_status import QueryStatus
from superset.common.query_actions import get_query_results
-from superset.common.query_object import QueryObject
from superset.common.utils import dataframe_utils as df_utils
from superset.common.utils.query_cache_manager import QueryCacheManager
+from superset.connectors.base.models import BaseDatasource
from superset.constants import CacheRegion
from superset.exceptions import QueryObjectValidationError, SupersetException
from superset.extensions import cache_manager, security_manager
@@ -54,7 +54,8 @@ from superset.utils.date_parser import get_past_or_future, normalize_time_delta
from superset.views.utils import get_viz
if TYPE_CHECKING:
- from superset.connectors.base.models import BaseDatasource
+ from superset.common.query_context import QueryContext
+ from superset.common.query_object import QueryObject
from superset.stats_logger import BaseStatsLogger
config = app.config
@@ -68,48 +69,166 @@ class CachedTimeOffset(TypedDict):
cache_keys: List[Optional[str]]
-class QueryContext:
+class QueryContextProcessor:
"""
The query context contains the query object and additional fields necessary
to retrieve the data payload for a given viz.
"""
+ _query_context: QueryContext
+ _qc_datasource: BaseDatasource
+ """
+ The query context contains the query object and additional fields necessary
+ to retrieve the data payload for a given viz.
+ """
+
+ def __init__(self, query_context: QueryContext):
+ self._query_context = query_context
+ self._qc_datasource = query_context.datasource
+
cache_type: ClassVar[str] = "df"
enforce_numerical_metrics: ClassVar[bool] = True
- datasource: BaseDatasource
- queries: List[QueryObject]
- result_type: ChartDataResultType
- result_format: ChartDataResultFormat
- force: bool
- custom_cache_timeout: Optional[int]
-
- cache_values: Dict[str, Any]
-
- # TODO: Type datasource and query_object dictionary with TypedDict when it becomes
- # a vanilla python type https://github.com/python/mypy/issues/5288
- def __init__(
- self,
- *,
- datasource: BaseDatasource,
- queries: List[QueryObject],
- result_type: ChartDataResultType,
- result_format: ChartDataResultFormat,
- force: bool = False,
- custom_cache_timeout: Optional[int] = None,
- cache_values: Dict[str, Any]
- ) -> None:
- self.datasource = datasource
- self.result_type = result_type
- self.result_format = result_format
- self.queries = queries
- self.force = force
- self.custom_cache_timeout = custom_cache_timeout
- self.cache_values = cache_values
+ def get_df_payload(
+ self, query_obj: QueryObject, force_cached: Optional[bool] = False
+ ) -> Dict[str, Any]:
+ """Handles caching around the df payload retrieval"""
+ cache_key = self.query_cache_key(query_obj)
+ cache = QueryCacheManager.get(
+ cache_key, CacheRegion.DATA, self._query_context.force, force_cached,
+ )
+
+ if query_obj and cache_key and not cache.is_loaded:
+ try:
+ invalid_columns = [
+ col
+ for col in get_column_names_from_columns(query_obj.columns)
+ + get_column_names_from_metrics(query_obj.metrics or [])
+ if (
+ col not in self._qc_datasource.column_names
+ and col != DTTM_ALIAS
+ )
+ ]
+ if invalid_columns:
+ raise QueryObjectValidationError(
+ _(
+ "Columns missing in datasource: %(invalid_columns)s",
+ invalid_columns=invalid_columns,
+ )
+ )
+ query_result = self.get_query_result(query_obj)
+ annotation_data = self.get_annotation_data(query_obj)
+ cache.set_query_result(
+ key=cache_key,
+ query_result=query_result,
+ annotation_data=annotation_data,
+ force_query=self._query_context.force,
+ timeout=self.get_cache_timeout(),
+ datasource_uid=self._qc_datasource.uid,
+ region=CacheRegion.DATA,
+ )
+ except QueryObjectValidationError as ex:
+ cache.error_message = str(ex)
+ cache.status = QueryStatus.FAILED
+
+ return {
+ "cache_key": cache_key,
+ "cached_dttm": cache.cache_dttm,
+ "cache_timeout": self.get_cache_timeout(),
+ "df": cache.df,
+ "applied_template_filters": cache.applied_template_filters,
+ "annotation_data": cache.annotation_data,
+ "error": cache.error_message,
+ "is_cached": cache.is_cached,
+ "query": cache.query,
+ "status": cache.status,
+ "stacktrace": cache.stacktrace,
+ "rowcount": len(cache.df.index),
+ }
+
+ def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
+ """
+ Returns a QueryObject cache key for objects in self.queries
+ """
+ datasource = self._qc_datasource
+ extra_cache_keys = datasource.get_extra_cache_keys(query_obj.to_dict())
+
+ cache_key = (
+ query_obj.cache_key(
+ datasource=datasource.uid,
+ extra_cache_keys=extra_cache_keys,
+ rls=security_manager.get_rls_ids(datasource)
+ if is_feature_enabled("ROW_LEVEL_SECURITY")
+ and datasource.is_rls_supported
+ else [],
+ changed_on=datasource.changed_on,
+ **kwargs,
+ )
+ if query_obj
+ else None
+ )
+ return cache_key
+
+ def get_query_result(self, query_object: QueryObject) -> QueryResult:
+ """Returns a pandas dataframe based on the query object"""
+ query_context = self._query_context
+ # Here, we assume that all the queries will use the same datasource, which is
+ # a valid assumption for current setting. In the long term, we may
+ # support multiple queries from different data sources.
+
+ # The datasource here can be different backend but the interface is common
+ result = query_context.datasource.query(query_object.to_dict())
+ query = result.query + ";\n\n"
+
+ df = result.df
+ # Transform the timestamp we received from database to pandas supported
+ # datetime format. If no python_date_format is specified, the pattern will
+ # be considered as the default ISO date format
+ # If the datetime format is unix, the parse will use the corresponding
+ # parsing logic
+ if not df.empty:
+ df = self.normalize_df(df, query_object)
+
+ if query_object.time_offsets:
+ time_offsets = self.processing_time_offsets(df, query_object)
+ df = time_offsets["df"]
+ queries = time_offsets["queries"]
+
+ query += ";\n\n".join(queries)
+ query += ";\n\n"
+
+ df = query_object.exec_post_processing(df)
+
+ result.df = df
+ result.query = query
+ return result
+
+ def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame:
+ datasource = self._qc_datasource
+ timestamp_format = None
+ if datasource.type == "table":
+ dttm_col = datasource.get_column(query_object.granularity)
+ if dttm_col:
+ timestamp_format = dttm_col.python_date_format
+
+ normalize_dttm_col(
+ df=df,
+ timestamp_format=timestamp_format,
+ offset=datasource.offset,
+ time_shift=query_object.time_shift,
+ )
+
+ if self.enforce_numerical_metrics:
+ df_utils.df_metrics_to_num(df, query_object)
+
+ df.replace([np.inf, -np.inf], np.nan, inplace=True)
+
+ return df
def processing_time_offsets( # pylint: disable=too-many-locals
self, df: pd.DataFrame, query_object: QueryObject,
) -> CachedTimeOffset:
+ query_context = self._query_context
# ensure query_object is immutable
query_object_clone = copy.copy(query_object)
queries: List[str] = []
@@ -142,7 +261,9 @@ class QueryContext:
)
# `offset` is added to the hash function
cache_key = self.query_cache_key(query_object_clone, time_offset=offset)
- cache = QueryCacheManager.get(cache_key, CacheRegion.DATA, self.force)
+ cache = QueryCacheManager.get(
+ cache_key, CacheRegion.DATA, query_context.force
+ )
# whether hit on the cache
if cache.is_loaded:
rv_dfs.append(cache.df)
@@ -160,7 +281,7 @@ class QueryContext:
}
join_keys = [col for col in df.columns if col not in metrics_mapping.keys()]
- result = self.datasource.query(query_object_clone_dct)
+ result = self._qc_datasource.query(query_object_clone_dct)
queries.append(result.query)
cache_keys.append(None)
@@ -200,8 +321,8 @@ class QueryContext:
cache.set(
key=cache_key,
value=value,
- timeout=self.cache_timeout,
- datasource_uid=self.datasource.uid,
+ timeout=self.get_cache_timeout(),
+ datasource_uid=query_context.datasource.uid,
region=CacheRegion.DATA,
)
rv_dfs.append(offset_slice)
@@ -209,63 +330,8 @@ class QueryContext:
rv_df = pd.concat(rv_dfs, axis=1, copy=False) if time_offsets else df
return CachedTimeOffset(df=rv_df, queries=queries, cache_keys=cache_keys)
- def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame:
- timestamp_format = None
- if self.datasource.type == "table":
- dttm_col = self.datasource.get_column(query_object.granularity)
- if dttm_col:
- timestamp_format = dttm_col.python_date_format
-
- normalize_dttm_col(
- df=df,
- timestamp_format=timestamp_format,
- offset=self.datasource.offset,
- time_shift=query_object.time_shift,
- )
-
- if self.enforce_numerical_metrics:
- df_utils.df_metrics_to_num(df, query_object)
-
- df.replace([np.inf, -np.inf], np.nan, inplace=True)
-
- return df
-
- def get_query_result(self, query_object: QueryObject) -> QueryResult:
- """Returns a pandas dataframe based on the query object"""
-
- # Here, we assume that all the queries will use the same datasource, which is
- # a valid assumption for current setting. In the long term, we may
- # support multiple queries from different data sources.
-
- # The datasource here can be different backend but the interface is common
- result = self.datasource.query(query_object.to_dict())
- query = result.query + ";\n\n"
-
- df = result.df
- # Transform the timestamp we received from database to pandas supported
- # datetime format. If no python_date_format is specified, the pattern will
- # be considered as the default ISO date format
- # If the datetime format is unix, the parse will use the corresponding
- # parsing logic
- if not df.empty:
- df = self.normalize_df(df, query_object)
-
- if query_object.time_offsets:
- time_offsets = self.processing_time_offsets(df, query_object)
- df = time_offsets["df"]
- queries = time_offsets["queries"]
-
- query += ";\n\n".join(queries)
- query += ";\n\n"
-
- df = query_object.exec_post_processing(df)
-
- result.df = df
- result.query = query
- return result
-
- def get_data(self, df: pd.DataFrame,) -> Union[str, List[Dict[str, Any]]]:
- if self.result_format == ChartDataResultFormat.CSV:
+ def get_data(self, df: pd.DataFrame) -> Union[str, List[Dict[str, Any]]]:
+ if self._query_context.result_format == ChartDataResultFormat.CSV:
include_index = not isinstance(df.index, pd.RangeIndex)
result = csv.df_to_escaped_csv(
df, index=include_index, **config["CSV_EXPORT"]
@@ -282,9 +348,12 @@ class QueryContext:
# Get all the payloads from the QueryObjects
query_results = [
get_query_results(
- query_obj.result_type or self.result_type, self, query_obj, force_cached
+ query_obj.result_type or self._query_context.result_type,
+ self._query_context,
+ query_obj,
+ force_cached,
)
- for query_obj in self.queries
+ for query_obj in self._query_context.queries
]
return_value = {"queries": query_results}
@@ -293,24 +362,17 @@ class QueryContext:
set_and_log_cache(
cache_manager.cache,
cache_key,
- {"data": self.cache_values},
- self.cache_timeout,
+ {"data": self._query_context.cache_values},
+ self.get_cache_timeout(),
)
return_value["cache_key"] = cache_key # type: ignore
return return_value
- @property
- def cache_timeout(self) -> int:
- if self.custom_cache_timeout is not None:
- return self.custom_cache_timeout
- if self.datasource.cache_timeout is not None:
- return self.datasource.cache_timeout
- if (
- hasattr(self.datasource, "database")
- and self.datasource.database.cache_timeout
- ) is not None:
- return self.datasource.database.cache_timeout
+ def get_cache_timeout(self) -> int:
+ cache_timeout_rv = self._query_context.get_cache_timeout()
+ if cache_timeout_rv:
+ return cache_timeout_rv
return config["CACHE_DEFAULT_TIMEOUT"]
def cache_key(self, **extra: Any) -> str:
@@ -320,32 +382,28 @@ class QueryContext:
required to rehydrate a QueryContext object.
"""
key_prefix = "qc-"
- cache_dict = self.cache_values.copy()
+ cache_dict = self._query_context.cache_values.copy()
cache_dict.update(extra)
return generate_cache_key(cache_dict, key_prefix)
- def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
+ def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]:
"""
- Returns a QueryObject cache key for objects in self.queries
+ :param query_context:
+ :param query_obj:
+ :return:
"""
- extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj.to_dict())
-
- cache_key = (
- query_obj.cache_key(
- datasource=self.datasource.uid,
- extra_cache_keys=extra_cache_keys,
- rls=security_manager.get_rls_ids(self.datasource)
- if is_feature_enabled("ROW_LEVEL_SECURITY")
- and self.datasource.is_rls_supported
- else [],
- changed_on=self.datasource.changed_on,
- **kwargs,
+ annotation_data: Dict[str, Any] = self.get_native_annotation_data(query_obj)
+ for annotation_layer in [
+ layer
+ for layer in query_obj.annotation_layers
+ if layer["sourceType"] in ("line", "table")
+ ]:
+ name = annotation_layer["name"]
+ annotation_data[name] = self.get_viz_annotation_data(
+ annotation_layer, self._query_context.force
)
- if query_obj
- else None
- )
- return cache_key
+ return annotation_data
@staticmethod
def get_native_annotation_data(query_obj: QueryObject) -> Dict[str, Any]:
@@ -401,84 +459,12 @@ class QueryContext:
except SupersetException as ex:
raise QueryObjectValidationError(error_msg_from_exception(ex)) from ex
- def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]:
- """
-
- :param query_obj:
- :return:
- """
- annotation_data: Dict[str, Any] = self.get_native_annotation_data(query_obj)
- for annotation_layer in [
- layer
- for layer in query_obj.annotation_layers
- if layer["sourceType"] in ("line", "table")
- ]:
- name = annotation_layer["name"]
- annotation_data[name] = self.get_viz_annotation_data(
- annotation_layer, self.force
- )
- return annotation_data
-
- def get_df_payload(
- self, query_obj: QueryObject, force_cached: Optional[bool] = False,
- ) -> Dict[str, Any]:
- """Handles caching around the df payload retrieval"""
- cache_key = self.query_cache_key(query_obj)
- cache = QueryCacheManager.get(
- cache_key, CacheRegion.DATA, self.force, force_cached,
- )
-
- if query_obj and cache_key and not cache.is_loaded:
- try:
- invalid_columns = [
- col
- for col in get_column_names_from_columns(query_obj.columns)
- + get_column_names_from_metrics(query_obj.metrics or [])
- if col not in self.datasource.column_names and col != DTTM_ALIAS
- ]
- if invalid_columns:
- raise QueryObjectValidationError(
- _(
- "Columns missing in datasource: %(invalid_columns)s",
- invalid_columns=invalid_columns,
- )
- )
- query_result = self.get_query_result(query_obj)
- annotation_data = self.get_annotation_data(query_obj)
- cache.set_query_result(
- key=cache_key,
- query_result=query_result,
- annotation_data=annotation_data,
- force_query=self.force,
- timeout=self.cache_timeout,
- datasource_uid=self.datasource.uid,
- region=CacheRegion.DATA,
- )
- except QueryObjectValidationError as ex:
- cache.error_message = str(ex)
- cache.status = QueryStatus.FAILED
-
- return {
- "cache_key": cache_key,
- "cached_dttm": cache.cache_dttm,
- "cache_timeout": self.cache_timeout,
- "df": cache.df,
- "applied_template_filters": cache.applied_template_filters,
- "annotation_data": cache.annotation_data,
- "error": cache.error_message,
- "is_cached": cache.is_cached,
- "query": cache.query,
- "status": cache.status,
- "stacktrace": cache.stacktrace,
- "rowcount": len(cache.df.index),
- }
-
def raise_for_access(self) -> None:
"""
Raise an exception if the user cannot access the resource.
:raises SupersetSecurityException: If the user cannot access the resource
"""
- for query in self.queries:
+ for query in self._query_context.queries:
query.validate()
- security_manager.raise_for_access(query_context=self)
+ security_manager.raise_for_access(query_context=self._query_context)
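The processing_time_offsets() logic carried over above is the densest part of the move. In rough terms: each time-offset query has its metrics renamed with the offset suffix and is then left-joined back onto the main result on the non-metric columns. The sketch below walks that step against a toy dataframe; the separator is a placeholder for the TIME_COMPARISION constant from superset.utils.core, whose actual value is not shown in this diff.

    import pandas as pd

    TIME_COMPARISION = ", "  # placeholder separator; the real constant lives in superset.utils.core

    # main query result and the "1 year ago" offset query result
    df = pd.DataFrame({"name": ["a", "b"], "SUM(value)": [10, 20]})
    offset_metrics_df = pd.DataFrame({"name": ["a", "b"], "SUM(value)": [7, 15]})

    offset = "1 year ago"
    # rename metrics: SUM(value) => SUM(value) 1 year ago (per the comment in the diff)
    metrics_mapping = {"SUM(value)": TIME_COMPARISION.join(["SUM(value)", offset])}
    join_keys = [col for col in df.columns if col not in metrics_mapping]

    # the offset query's metrics get the suffixed names, then a left join
    # on the non-metric columns stitches them onto the main dataframe
    offset_metrics_df = offset_metrics_df.rename(columns=metrics_mapping)
    joined = df.merge(offset_metrics_df, on=join_keys, how="left")
    print(joined.columns.tolist())  # ['name', 'SUM(value)', 'SUM(value), 1 year ago']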
diff --git a/tests/integration_tests/charts/schema_tests.py b/tests/integration_tests/charts/schema_tests.py
index 95e1c07..821a8d5 100644
--- a/tests/integration_tests/charts/schema_tests.py
+++ b/tests/integration_tests/charts/schema_tests.py
@@ -32,7 +32,8 @@ from tests.integration_tests.fixtures.query_context import get_query_context
class TestSchema(SupersetTestCase):
@mock.patch(
- "superset.common.query_context.config", {**app.config, "ROW_LIMIT": 5000},
+ "superset.common.query_context_factory.config",
+ {**app.config, "ROW_LIMIT": 5000},
)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_query_context_limit_and_offset(self):