You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by vi...@apache.org on 2021/05/21 16:37:59 UTC

[superset] branch master updated: feat: Add a remove filter_flag to jinja filter_values function (#14507)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 590fe20  feat: Add a remove filter_flag to jinja filter_values function (#14507)
590fe20 is described below

commit 590fe20a450d49695f1ad48161e7119a705a7cf5
Author: cccs-jc <56...@users.noreply.github.com>
AuthorDate: Fri May 21 12:37:09 2021 -0400

    feat: Add a remove filter_flag to jinja filter_values function (#14507)
    
    Implementation issue 13943
    
    Co-authored-by: cccs-jc <cc...@cyber.gc.ca>
---
 RESOURCES/FEATURE_FLAGS.md         |   1 +
 superset/config.py                 |   1 +
 superset/connectors/sqla/models.py |   8 ++
 superset/jinja_context.py          | 218 ++++++++++++++++++++++++++++---------
 tests/jinja_context_tests.py       | 119 +++++++++++++++++---
 5 files changed, 278 insertions(+), 69 deletions(-)

diff --git a/RESOURCES/FEATURE_FLAGS.md b/RESOURCES/FEATURE_FLAGS.md
index 47958a7..5461db7 100644
--- a/RESOURCES/FEATURE_FLAGS.md
+++ b/RESOURCES/FEATURE_FLAGS.md
@@ -33,6 +33,7 @@ These features are considered **unfinished** and should only be used on developm
 - REMOVE_SLICE_LEVEL_LABEL_COLORS
 - SHARE_QUERIES_VIA_KV_STORE
 - TAGGING_SYSTEM
+- ENABLE_TEMPLATE_REMOVE_FILTERS
 
 ## In Testing
 These features are **finished** but currently being tested. They are usable, but may still contain some bugs.
diff --git a/superset/config.py b/superset/config.py
index 9863f76..7f26ac2 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -334,6 +334,7 @@ DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
     # See `PR 7935 <https://github.com/apache/superset/pull/7935>`_ for more details.
     "ENABLE_EXPLORE_JSON_CSRF_PROTECTION": False,
     "ENABLE_TEMPLATE_PROCESSING": False,
+    "ENABLE_TEMPLATE_REMOVE_FILTERS": False,
     "KV_STORE": False,
     # When this feature is enabled, nested types in Presto will be
     # expanded into extra columns and/or arrays. This is experimental,
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 1b4343c..7f62d4c 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -985,6 +985,8 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
         template_kwargs.update(self.template_params_dict)
         extra_cache_keys: List[Any] = []
         template_kwargs["extra_cache_keys"] = extra_cache_keys
+        removed_filters: List[str] = []
+        template_kwargs["removed_filters"] = removed_filters
         template_processor = self.get_template_processor(**template_kwargs)
         db_engine_spec = self.db_engine_spec
         prequeries: List[str] = []
@@ -1168,6 +1170,12 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
             val = flt.get("val")
             op = flt["op"].upper()
             col_obj = columns_by_name.get(col)
+
+            if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"):
+                if col in removed_filters:
+                    # Skip generating SQLA filter when the jinja template handles it.
+                    continue
+
             if col_obj:
                 col_spec = db_engine_spec.get_column_spec(col_obj.type)
                 is_list_target = op in (
diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index 54fc09d..b6fb855 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -18,12 +18,23 @@
 import json
 import re
 from functools import partial
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TYPE_CHECKING
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)
 
 from flask import current_app, g, request
 from flask_babel import gettext as _
 from jinja2 import DebugUndefined
 from jinja2.sandbox import SandboxedEnvironment
+from typing_extensions import TypedDict
 
 from superset.exceptions import SupersetTemplateException
 from superset.extensions import feature_flag_manager
@@ -60,56 +71,10 @@ def context_addons() -> Dict[str, Any]:
     return current_app.config.get("JINJA_CONTEXT_ADDONS", {})
 
 
-def filter_values(column: str, default: Optional[str] = None) -> List[str]:
-    """Gets a values for a particular filter as a list
-
-    This is useful if:
-        - you want to use a filter box to filter a query where the name of filter box
-          column doesn't match the one in the select statement
-        - you want to have the ability for filter inside the main query for speed
-          purposes
-
-    Usage example::
-
-        SELECT action, count(*) as times
-        FROM logs
-        WHERE action in ( {{ "'" + "','".join(filter_values('action_type')) + "'" }} )
-        GROUP BY action
-
-    :param column: column/filter name to lookup
-    :param default: default value to return if there's no matching columns
-    :return: returns a list of filter values
-    """
-
-    from superset.views.utils import get_form_data
-
-    form_data, _ = get_form_data()
-    convert_legacy_filters_into_adhoc(form_data)
-    merge_extra_filters(form_data)
-
-    return_val = [
-        comparator
-        for filter in form_data.get("adhoc_filters", [])
-        for comparator in (
-            filter["comparator"]
-            if isinstance(filter["comparator"], list)
-            else [filter["comparator"]]
-        )
-        if (
-            filter.get("expressionType") == "SIMPLE"
-            and filter.get("clause") == "WHERE"
-            and filter.get("subject") == column
-            and filter.get("comparator")
-        )
-    ]
-
-    if return_val:
-        return return_val
-
-    if default:
-        return [default]
-
-    return []
+class Filter(TypedDict):
+    op: str  # pylint: disable=C0103
+    col: str
+    val: Union[None, Any, List[Any]]
 
 
 class ExtraCache:
@@ -129,8 +94,13 @@ class ExtraCache:
         r").*\}\}"
     )
 
-    def __init__(self, extra_cache_keys: Optional[List[Any]] = None):
+    def __init__(
+        self,
+        extra_cache_keys: Optional[List[Any]] = None,
+        removed_filters: Optional[List[str]] = None,
+    ):
         self.extra_cache_keys = extra_cache_keys
+        self.removed_filters = removed_filters if removed_filters is not None else []
 
     def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]:
         """
@@ -213,6 +183,142 @@ class ExtraCache:
             self.cache_key_wrapper(result)
         return result
 
+    def filter_values(
+        self, column: str, default: Optional[str] = None, remove_filter: bool = False
+    ) -> List[Any]:
+        """Gets the values for a particular filter as a list
+
+        This is useful if:
+            - you want to use a filter component to filter a query where the name of
+              the filter component column doesn't match the one in the select statement
+            - you want the ability to filter inside the main query for speed
+              purposes
+
+        Usage example::
+
+            SELECT action, count(*) as times
+            FROM logs
+            WHERE
+                action in ({{ "'" + "','".join(filter_values('action_type')) + "'" }})
+            GROUP BY action
+
+        :param column: column/filter name to lookup
+        :param default: default value to return if no values are found for the column
+        :param remove_filter: When set to true, mark the filter as processed,
+            removing it from the outer query. Useful when a filter should
+            only apply to the inner query
+        :return: returns a list of filter values
+        """
+        return_val: List[Any] = []
+        filters = self.get_filters(column, remove_filter)
+        for flt in filters:
+            val = flt.get("val")
+            if isinstance(val, list):
+                return_val.extend(val)
+            elif val:
+                return_val.append(val)
+
+        if (not return_val) and default:
+            # If no values are found, return the default provided.
+            return_val = [default]
+
+        return return_val
+
+    def get_filters(self, column: str, remove_filter: bool = False) -> List[Filter]:
+        """Get the filters applied to the given column. In addition
+           to returning values like the filter_values function,
+           the get_filters function returns the operator specified in the explorer UI.
+
+        This is useful if:
+            - you want to handle more than the IN operator in your SQL clause
+            - you want to handle generating custom SQL conditions for a filter
+            - you want the ability to filter inside the main query for speed
+              purposes
+
+        Usage example::
+
+
+            WITH RECURSIVE
+                superiors(employee_id, manager_id, full_name, level, lineage) AS (
+                SELECT
+                    employee_id,
+                    manager_id,
+                    full_name,
+                1 as level,
+                employee_id as lineage
+                FROM
+                    employees
+                WHERE
+                1=1
+                {# Render a blank line #}
+                {%- for filter in get_filters('full_name', remove_filter=True) -%}
+                {%- if filter.get('op') == 'IN' -%}
+                    AND
+                    full_name IN ( {{ "'" + "', '".join(filter.get('val')) + "'" }} )
+                {%- endif -%}
+                {%- if filter.get('op') == 'LIKE' -%}
+                    AND
+                    full_name LIKE {{ "'" + filter.get('val') + "'" }}
+                {%- endif -%}
+                {%- endfor -%}
+                UNION ALL
+                    SELECT
+                        e.employee_id,
+                        e.manager_id,
+                        e.full_name,
+                s.level + 1 as level,
+                s.lineage
+                    FROM
+                        employees e,
+                    superiors s
+                    WHERE s.manager_id = e.employee_id
+            )
+
+
+            SELECT
+                employee_id, manager_id, full_name, level, lineage
+            FROM
+                superiors
+            order by lineage, level
+
+        :param column: column/filter name to lookup
+        :param remove_filter: When set to true, mark the filter as processed,
+            removing it from the outer query. Useful when a filter should
+            only apply to the inner query
+        :return: returns a list of filters
+        """
+        from superset.utils.core import FilterOperator
+        from superset.views.utils import get_form_data
+
+        form_data, _ = get_form_data()
+        convert_legacy_filters_into_adhoc(form_data)
+        merge_extra_filters(form_data)
+
+        filters: List[Filter] = []
+
+        for flt in form_data.get("adhoc_filters", []):
+            val: Union[str, List[str]] = flt.get("comparator")
+            op: str = flt["operator"].upper() if "operator" in flt else None
+            # fltOpName: str = flt.get("filterOptionName")
+            if (
+                flt.get("expressionType") == "SIMPLE"
+                and flt.get("clause") == "WHERE"
+                and flt.get("subject") == column
+                and val
+            ):
+                if remove_filter:
+                    if column not in self.removed_filters:
+                        self.removed_filters.append(column)
+                if op in (
+                    FilterOperator.IN.value,
+                    FilterOperator.NOT_IN.value,
+                ) and not isinstance(val, list):
+                    val = [val]
+
+                filters.append({"op": op, "col": column, "val": val})
+
+        return filters
+
 
 def safe_proxy(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
     return_value = func(*args, **kwargs)
@@ -280,12 +386,14 @@ class BaseTemplateProcessor:  # pylint: disable=too-few-public-methods
 
     engine: Optional[str] = None
 
+    # pylint: disable=too-many-arguments
     def __init__(
         self,
         database: "Database",
         query: Optional["Query"] = None,
         table: Optional["SqlaTable"] = None,
         extra_cache_keys: Optional[List[Any]] = None,
+        removed_filters: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
         self._database = database
@@ -296,6 +404,7 @@ class BaseTemplateProcessor:  # pylint: disable=too-few-public-methods
         elif table:
             self._schema = table.schema
         self._extra_cache_keys = extra_cache_keys
+        self._removed_filters = removed_filters
         self._context: Dict[str, Any] = {}
         self._env = SandboxedEnvironment(undefined=DebugUndefined)
         self.set_context(**kwargs)
@@ -321,14 +430,15 @@ class BaseTemplateProcessor:  # pylint: disable=too-few-public-methods
 class JinjaTemplateProcessor(BaseTemplateProcessor):
     def set_context(self, **kwargs: Any) -> None:
         super().set_context(**kwargs)
-        extra_cache = ExtraCache(self._extra_cache_keys)
+        extra_cache = ExtraCache(self._extra_cache_keys, self._removed_filters)
         self._context.update(
             {
                 "url_param": partial(safe_proxy, extra_cache.url_param),
                 "current_user_id": partial(safe_proxy, extra_cache.current_user_id),
                 "current_username": partial(safe_proxy, extra_cache.current_username),
                 "cache_key_wrapper": partial(safe_proxy, extra_cache.cache_key_wrapper),
-                "filter_values": partial(safe_proxy, filter_values),
+                "filter_values": partial(safe_proxy, extra_cache.filter_values),
+                "get_filters": partial(safe_proxy, extra_cache.get_filters),
             }
         )
 
diff --git a/tests/jinja_context_tests.py b/tests/jinja_context_tests.py
index a3314d0..0628bbe 100644
--- a/tests/jinja_context_tests.py
+++ b/tests/jinja_context_tests.py
@@ -24,12 +24,7 @@ import pytest
 import tests.test_app
 from superset import app
 from superset.exceptions import SupersetTemplateException
-from superset.jinja_context import (
-    ExtraCache,
-    filter_values,
-    get_template_processor,
-    safe_proxy,
-)
+from superset.jinja_context import ExtraCache, get_template_processor, safe_proxy
 from superset.utils import core as utils
 from tests.base_tests import SupersetTestCase
 
@@ -37,11 +32,26 @@ from tests.base_tests import SupersetTestCase
 class TestJinja2Context(SupersetTestCase):
     def test_filter_values_default(self) -> None:
         with app.test_request_context():
-            self.assertEqual(filter_values("name", "foo"), ["foo"])
+            cache = ExtraCache()
+            self.assertEqual(cache.filter_values("name", "foo"), ["foo"])
+            self.assertEqual(cache.removed_filters, list())
+
+    def test_filter_values_remove_not_present(self) -> None:
+        with app.test_request_context():
+            cache = ExtraCache()
+            self.assertEqual(cache.filter_values("name", remove_filter=True), [])
+            self.assertEqual(cache.removed_filters, list())
+
+    def test_get_filters_remove_not_present(self) -> None:
+        with app.test_request_context():
+            cache = ExtraCache()
+            self.assertEqual(cache.get_filters("name", remove_filter=True), [])
+            self.assertEqual(cache.removed_filters, list())
 
     def test_filter_values_no_default(self) -> None:
         with app.test_request_context():
-            self.assertEqual(filter_values("name"), [])
+            cache = ExtraCache()
+            self.assertEqual(cache.filter_values("name"), [])
 
     def test_filter_values_adhoc_filters(self) -> None:
         with app.test_request_context(
@@ -61,7 +71,76 @@ class TestJinja2Context(SupersetTestCase):
                 )
             }
         ):
-            self.assertEqual(filter_values("name"), ["foo"])
+            cache = ExtraCache()
+            self.assertEqual(cache.filter_values("name"), ["foo"])
+
+        with app.test_request_context(
+            data={
+                "form_data": json.dumps(
+                    {
+                        "adhoc_filters": [
+                            {
+                                "clause": "WHERE",
+                                "comparator": ["foo", "bar"],
+                                "expressionType": "SIMPLE",
+                                "operator": "in",
+                                "subject": "name",
+                            }
+                        ],
+                    }
+                )
+            }
+        ):
+            cache = ExtraCache()
+            self.assertEqual(cache.filter_values("name"), ["foo", "bar"])
+
+    def test_get_filters_adhoc_filters(self) -> None:
+        with app.test_request_context(
+            data={
+                "form_data": json.dumps(
+                    {
+                        "adhoc_filters": [
+                            {
+                                "clause": "WHERE",
+                                "comparator": "foo",
+                                "expressionType": "SIMPLE",
+                                "operator": "in",
+                                "subject": "name",
+                            }
+                        ],
+                    }
+                )
+            }
+        ):
+            cache = ExtraCache()
+            self.assertEqual(
+                cache.get_filters("name"), [{"op": "IN", "col": "name", "val": ["foo"]}]
+            )
+            self.assertEqual(cache.removed_filters, list())
+
+        with app.test_request_context(
+            data={
+                "form_data": json.dumps(
+                    {
+                        "adhoc_filters": [
+                            {
+                                "clause": "WHERE",
+                                "comparator": ["foo", "bar"],
+                                "expressionType": "SIMPLE",
+                                "operator": "in",
+                                "subject": "name",
+                            }
+                        ],
+                    }
+                )
+            }
+        ):
+            cache = ExtraCache()
+            self.assertEqual(
+                cache.get_filters("name"),
+                [{"op": "IN", "col": "name", "val": ["foo", "bar"]}],
+            )
+            self.assertEqual(cache.removed_filters, list())
 
         with app.test_request_context(
             data={
@@ -80,7 +159,12 @@ class TestJinja2Context(SupersetTestCase):
                 )
             }
         ):
-            self.assertEqual(filter_values("name"), ["foo", "bar"])
+            cache = ExtraCache()
+            self.assertEqual(
+                cache.get_filters("name", remove_filter=True),
+                [{"op": "IN", "col": "name", "val": ["foo", "bar"]}],
+            )
+            self.assertEqual(cache.removed_filters, ["name"])
 
     def test_filter_values_extra_filters(self) -> None:
         with app.test_request_context(
@@ -90,25 +174,30 @@ class TestJinja2Context(SupersetTestCase):
                 )
             }
         ):
-            self.assertEqual(filter_values("name"), ["foo"])
+            cache = ExtraCache()
+            self.assertEqual(cache.filter_values("name"), ["foo"])
 
     def test_url_param_default(self) -> None:
         with app.test_request_context():
-            self.assertEqual(ExtraCache().url_param("foo", "bar"), "bar")
+            cache = ExtraCache()
+            self.assertEqual(cache.url_param("foo", "bar"), "bar")
 
     def test_url_param_no_default(self) -> None:
         with app.test_request_context():
-            self.assertEqual(ExtraCache().url_param("foo"), None)
+            cache = ExtraCache()
+            self.assertEqual(cache.url_param("foo"), None)
 
     def test_url_param_query(self) -> None:
         with app.test_request_context(query_string={"foo": "bar"}):
-            self.assertEqual(ExtraCache().url_param("foo"), "bar")
+            cache = ExtraCache()
+            self.assertEqual(cache.url_param("foo"), "bar")
 
     def test_url_param_form_data(self) -> None:
         with app.test_request_context(
             query_string={"form_data": json.dumps({"url_params": {"foo": "bar"}})}
         ):
-            self.assertEqual(ExtraCache().url_param("foo"), "bar")
+            cache = ExtraCache()
+            self.assertEqual(cache.url_param("foo"), "bar")
 
     def test_safe_proxy_primitive(self) -> None:
         def func(input: Any) -> Any: