You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by yo...@apache.org on 2022/08/08 14:42:22 UTC

[superset] branch master updated: feat: supports multiple filters in samples endpoint (#21008)

This is an automated email from the ASF dual-hosted git repository.

yongjiezhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 802b69f97b feat: supports multiple filters in samples endpoint (#21008)
802b69f97b is described below

commit 802b69f97bb9fd35fe8aed225cfd6a03875cf747
Author: Yongjie Zhao <yo...@gmail.com>
AuthorDate: Mon Aug 8 22:42:14 2022 +0800

    feat: supports multiple filters in samples endpoint (#21008)
---
 superset/common/chart_data.py               |   1 +
 superset/common/query_actions.py            |  22 +++++
 superset/views/datasource/schemas.py        |  13 ++-
 superset/views/datasource/utils.py          |  35 +++++---
 tests/integration_tests/conftest.py         |  11 +--
 tests/integration_tests/datasource_tests.py | 132 +++++++++++++++++++---------
 6 files changed, 154 insertions(+), 60 deletions(-)

diff --git a/superset/common/chart_data.py b/superset/common/chart_data.py
index f3917d6d87..ea31d4f138 100644
--- a/superset/common/chart_data.py
+++ b/superset/common/chart_data.py
@@ -38,3 +38,4 @@ class ChartDataResultType(str, Enum):
     SAMPLES = "samples"
     TIMEGRAINS = "timegrains"
     POST_PROCESSED = "post_processed"
+    DRILL_DETAIL = "drill_detail"
diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py
index 0764e19340..bfb3d36878 100644
--- a/superset/common/query_actions.py
+++ b/superset/common/query_actions.py
@@ -162,6 +162,27 @@ def _get_samples(
     return _get_full(query_context, query_obj, force_cached)
 
 
+def _get_drill_detail(
+    query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
+) -> Dict[str, Any]:
+    # todo(yongjie): Remove this function,
+    #  when determining whether samples should be applied to the time filter.
+    datasource = _get_datasource(query_context, query_obj)
+    query_obj = copy.copy(query_obj)
+    query_obj.is_timeseries = False
+    query_obj.orderby = []
+    query_obj.metrics = None
+    query_obj.post_processing = []
+    qry_obj_cols = []
+    for o in datasource.columns:
+        if isinstance(o, dict):
+            qry_obj_cols.append(o.get("column_name"))
+        else:
+            qry_obj_cols.append(o.column_name)
+    query_obj.columns = qry_obj_cols
+    return _get_full(query_context, query_obj, force_cached)
+
+
 def _get_results(
     query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
 ) -> Dict[str, Any]:
@@ -182,6 +203,7 @@ _result_type_functions: Dict[
     # and post-process it later where we have the chart context, since
     # post-processing is unique to each visualization type
     ChartDataResultType.POST_PROCESSED: _get_full,
+    ChartDataResultType.DRILL_DETAIL: _get_drill_detail,
 }
 
 
diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py
index 4c97f17e88..f9be7a7d4e 100644
--- a/superset/views/datasource/schemas.py
+++ b/superset/views/datasource/schemas.py
@@ -20,7 +20,7 @@ from marshmallow import fields, post_load, pre_load, Schema, validate
 from typing_extensions import TypedDict
 
 from superset import app
-from superset.charts.schemas import ChartDataFilterSchema
+from superset.charts.schemas import ChartDataExtrasSchema, ChartDataFilterSchema
 from superset.utils.core import DatasourceType
 
 
@@ -62,6 +62,17 @@ class ExternalMetadataSchema(Schema):
 
 class SamplesPayloadSchema(Schema):
     filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
+    granularity = fields.String(
+        allow_none=True,
+    )
+    time_range = fields.String(
+        allow_none=True,
+    )
+    extras = fields.Nested(
+        ChartDataExtrasSchema,
+        description="Extra parameters to add to the query.",
+        allow_none=True,
+    )
 
     @pre_load
     # pylint: disable=no-self-use, unused-argument
diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py
index 0191db2947..42cddf4167 100644
--- a/superset/views/datasource/utils.py
+++ b/superset/views/datasource/utils.py
@@ -60,17 +60,30 @@ def get_samples(  # pylint: disable=too-many-arguments,too-many-locals
     limit_clause = get_limit_clause(page, per_page)
 
     # todo(yongjie): Constructing count(*) and samples in the same query_context,
-    #  then remove query_type==SAMPLES
-    # constructing samples query
-    samples_instance = QueryContextFactory().create(
-        datasource={
-            "type": datasource.type,
-            "id": datasource.id,
-        },
-        queries=[{**payload, **limit_clause} if payload else limit_clause],
-        result_type=ChartDataResultType.SAMPLES,
-        force=force,
-    )
+    if payload is None:
+        # constructing samples query
+        samples_instance = QueryContextFactory().create(
+            datasource={
+                "type": datasource.type,
+                "id": datasource.id,
+            },
+            queries=[limit_clause],
+            result_type=ChartDataResultType.SAMPLES,
+            force=force,
+        )
+    else:
+        # constructing drill detail query
+        # When query_type == 'samples' the `time filter` will be removed,
+        # so it is not applicable drill detail query
+        samples_instance = QueryContextFactory().create(
+            datasource={
+                "type": datasource.type,
+                "id": datasource.id,
+            },
+            queries=[{**payload, **limit_clause}],
+            result_type=ChartDataResultType.DRILL_DETAIL,
+            force=force,
+        )
 
     # constructing count(*) query
     count_star_metric = {
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index 043d792219..549a987db1 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -314,7 +314,7 @@ def physical_dataset():
           col2 VARCHAR(255),
           col3 DECIMAL(4,2),
           col4 VARCHAR(255),
-          col5 VARCHAR(255)
+          col5 TIMESTAMP
         );
         """
     )
@@ -342,11 +342,10 @@ def physical_dataset():
     TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
     TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
     TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
-    TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
+    TableColumn(column_name="col5", type="TIMESTAMP", is_dttm=True, table=dataset)
     SqlMetric(metric_name="count", expression="count(*)", table=dataset)
     db.session.merge(dataset)
-    if example_database.backend == "sqlite":
-        db.session.commit()
+    db.session.commit()
 
     yield dataset
 
@@ -355,5 +354,7 @@ def physical_dataset():
         DROP TABLE physical_dataset;
     """
     )
-    db.session.delete(dataset)
+    dataset = db.session.query(SqlaTable).filter_by(table_name="physical_dataset").all()
+    for ds in dataset:
+        db.session.delete(ds)
     db.session.commit()
diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py
index ad4d625cc5..ef3ba0c69d 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -432,14 +432,13 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
     test_client.post(uri)
     # get from cache
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
     assert rv.status_code == 200
-    assert len(rv_data["result"]["data"]) == 10
+    assert len(rv.json["result"]["data"]) == 10
     assert QueryCacheManager.has(
-        rv_data["result"]["cache_key"],
+        rv.json["result"]["cache_key"],
         region=CacheRegion.DATA,
     )
-    assert rv_data["result"]["is_cached"]
+    assert rv.json["result"]["is_cached"]
 
     # 2. should read through cache data
     uri2 = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true"
@@ -447,19 +446,18 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
     test_client.post(uri2)
     # force query
     rv2 = test_client.post(uri2)
-    rv_data2 = json.loads(rv2.data)
     assert rv2.status_code == 200
-    assert len(rv_data2["result"]["data"]) == 10
+    assert len(rv2.json["result"]["data"]) == 10
     assert QueryCacheManager.has(
-        rv_data2["result"]["cache_key"],
+        rv2.json["result"]["cache_key"],
         region=CacheRegion.DATA,
     )
-    assert not rv_data2["result"]["is_cached"]
+    assert not rv2.json["result"]["is_cached"]
 
     # 3. data precision
-    assert "colnames" in rv_data2["result"]
-    assert "coltypes" in rv_data2["result"]
-    assert "data" in rv_data2["result"]
+    assert "colnames" in rv2.json["result"]
+    assert "coltypes" in rv2.json["result"]
+    assert "data" in rv2.json["result"]
 
     eager_samples = virtual_dataset.database.get_df(
         f"select * from ({virtual_dataset.sql}) as tbl"
@@ -468,7 +466,7 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
     # the col3 is Decimal
     eager_samples["col3"] = eager_samples["col3"].apply(float)
     eager_samples = eager_samples.to_dict(orient="records")
-    assert eager_samples == rv_data2["result"]["data"]
+    assert eager_samples == rv2.json["result"]["data"]
 
 
 def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
@@ -486,10 +484,9 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
     rv = test_client.post(uri)
     assert rv.status_code == 422
 
-    rv_data = json.loads(rv.data)
-    assert "error" in rv_data
+    assert "error" in rv.json
     if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
-        assert "INCORRECT SQL" in rv_data.get("error")
+        assert "INCORRECT SQL" in rv.json.get("error")
 
 
 def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
@@ -498,11 +495,10 @@ def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_d
     )
     rv = test_client.post(uri)
     assert rv.status_code == 200
-    rv_data = json.loads(rv.data)
     assert QueryCacheManager.has(
-        rv_data["result"]["cache_key"], region=CacheRegion.DATA
+        rv.json["result"]["cache_key"], region=CacheRegion.DATA
     )
-    assert len(rv_data["result"]["data"]) == 10
+    assert len(rv.json["result"]["data"]) == 10
 
 
 def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
@@ -533,9 +529,8 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
         },
     )
     assert rv.status_code == 200
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
-    assert rv_data["result"]["rowcount"] == 1
+    assert rv.json["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
+    assert rv.json["result"]["rowcount"] == 1
 
     # empty results
     rv = test_client.post(
@@ -547,9 +542,64 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
         },
     )
     assert rv.status_code == 200
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["colnames"] == []
-    assert rv_data["result"]["rowcount"] == 0
+    assert rv.json["result"]["colnames"] == []
+    assert rv.json["result"]["rowcount"] == 0
+
+
+def test_get_samples_with_time_filter(test_client, login_as_admin, physical_dataset):
+    uri = (
+        f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
+    )
+    payload = {
+        "granularity": "col5",
+        "time_range": "2000-01-02 : 2000-01-04",
+    }
+    rv = test_client.post(uri, json=payload)
+    assert len(rv.json["result"]["data"]) == 2
+    if physical_dataset.database.backend != "sqlite":
+        assert [row["col5"] for row in rv.json["result"]["data"]] == [
+            946771200000.0,  # 2000-01-02 00:00:00
+            946857600000.0,  # 2000-01-03 00:00:00
+        ]
+    assert rv.json["result"]["page"] == 1
+    assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
+    assert rv.json["result"]["total_count"] == 2
+
+
+def test_get_samples_with_multiple_filters(
+    test_client, login_as_admin, physical_dataset
+):
+    # 1. empty response
+    uri = (
+        f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
+    )
+    payload = {
+        "granularity": "col5",
+        "time_range": "2000-01-02 : 2000-01-04",
+        "filters": [
+            {"col": "col4", "op": "IS NOT NULL"},
+        ],
+    }
+    rv = test_client.post(uri, json=payload)
+    assert len(rv.json["result"]["data"]) == 0
+
+    # 2. adhoc filters, time filters, and custom where
+    payload = {
+        "granularity": "col5",
+        "time_range": "2000-01-02 : 2000-01-04",
+        "filters": [
+            {"col": "col2", "op": "==", "val": "c"},
+        ],
+        "extras": {"where": "col3 = 1.2 and col4 is null"},
+    }
+    rv = test_client.post(uri, json=payload)
+    assert len(rv.json["result"]["data"]) == 1
+    assert rv.json["result"]["total_count"] == 1
+    assert "2000-01-02" in rv.json["result"]["query"]
+    assert "2000-01-04" in rv.json["result"]["query"]
+    assert "col3 = 1.2" in rv.json["result"]["query"]
+    assert "col4 is null" in rv.json["result"]["query"]
+    assert "col2 = 'c'" in rv.json["result"]["query"]
 
 
 def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
@@ -558,10 +608,9 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
         f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
     )
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["page"] == 1
-    assert rv_data["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
-    assert rv_data["result"]["total_count"] == 10
+    assert rv.json["result"]["page"] == 1
+    assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
+    assert rv.json["result"]["total_count"] == 10
 
     # 2. incorrect per_page
     per_pages = (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx")
@@ -582,25 +631,22 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
     # 4. turning pages
     uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=1"
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["page"] == 1
-    assert rv_data["result"]["per_page"] == 2
-    assert rv_data["result"]["total_count"] == 10
-    assert [row["col1"] for row in rv_data["result"]["data"]] == [0, 1]
+    assert rv.json["result"]["page"] == 1
+    assert rv.json["result"]["per_page"] == 2
+    assert rv.json["result"]["total_count"] == 10
+    assert [row["col1"] for row in rv.json["result"]["data"]] == [0, 1]
 
     uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=2"
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["page"] == 2
-    assert rv_data["result"]["per_page"] == 2
-    assert rv_data["result"]["total_count"] == 10
-    assert [row["col1"] for row in rv_data["result"]["data"]] == [2, 3]
+    assert rv.json["result"]["page"] == 2
+    assert rv.json["result"]["per_page"] == 2
+    assert rv.json["result"]["total_count"] == 10
+    assert [row["col1"] for row in rv.json["result"]["data"]] == [2, 3]
 
     # 5. Exceeding the maximum pages
     uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=6"
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["page"] == 6
-    assert rv_data["result"]["per_page"] == 2
-    assert rv_data["result"]["total_count"] == 10
-    assert [row["col1"] for row in rv_data["result"]["data"]] == []
+    assert rv.json["result"]["page"] == 6
+    assert rv.json["result"]["per_page"] == 2
+    assert rv.json["result"]["total_count"] == 10
+    assert [row["col1"] for row in rv.json["result"]["data"]] == []