Posted to commits@superset.apache.org by yo...@apache.org on 2022/08/08 14:42:22 UTC
[superset] branch master updated: feat: supports multiple filters in samples endpoint (#21008)
This is an automated email from the ASF dual-hosted git repository.
yongjiezhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 802b69f97b feat: supports multiple filters in samples endpoint (#21008)
802b69f97b is described below
commit 802b69f97bb9fd35fe8aed225cfd6a03875cf747
Author: Yongjie Zhao <yo...@gmail.com>
AuthorDate: Mon Aug 8 22:42:14 2022 +0800
feat: supports multiple filters in samples endpoint (#21008)
---
superset/common/chart_data.py | 1 +
superset/common/query_actions.py | 22 +++++
superset/views/datasource/schemas.py | 13 ++-
superset/views/datasource/utils.py | 35 +++++---
tests/integration_tests/conftest.py | 11 +--
tests/integration_tests/datasource_tests.py | 132 +++++++++++++++++++---------
6 files changed, 154 insertions(+), 60 deletions(-)
diff --git a/superset/common/chart_data.py b/superset/common/chart_data.py
index f3917d6d87..ea31d4f138 100644
--- a/superset/common/chart_data.py
+++ b/superset/common/chart_data.py
@@ -38,3 +38,4 @@ class ChartDataResultType(str, Enum):
SAMPLES = "samples"
TIMEGRAINS = "timegrains"
POST_PROCESSED = "post_processed"
+ DRILL_DETAIL = "drill_detail"
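For context, ChartDataResultType is a str-mixin Enum, so the new member round-trips directly from the raw string a client sends as result_type. A minimal sketch, trimmed to two members for illustration:

    from enum import Enum

    class ChartDataResultType(str, Enum):
        SAMPLES = "samples"
        DRILL_DETAIL = "drill_detail"

    # Value lookup maps an incoming JSON string onto the member, and the
    # str mixin lets the member compare equal to its raw value.
    assert ChartDataResultType("drill_detail") is ChartDataResultType.DRILL_DETAIL
    assert ChartDataResultType.DRILL_DETAIL == "drill_detail"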
diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py
index 0764e19340..bfb3d36878 100644
--- a/superset/common/query_actions.py
+++ b/superset/common/query_actions.py
@@ -162,6 +162,27 @@ def _get_samples(
return _get_full(query_context, query_obj, force_cached)
+def _get_drill_detail(
+ query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
+) -> Dict[str, Any]:
+ # TODO(yongjie): remove this function once it is decided
+ # whether the time filter should be applied to samples.
+ datasource = _get_datasource(query_context, query_obj)
+ query_obj = copy.copy(query_obj)
+ query_obj.is_timeseries = False
+ query_obj.orderby = []
+ query_obj.metrics = None
+ query_obj.post_processing = []
+ qry_obj_cols = []
+ for o in datasource.columns:
+ if isinstance(o, dict):
+ qry_obj_cols.append(o.get("column_name"))
+ else:
+ qry_obj_cols.append(o.column_name)
+ query_obj.columns = qry_obj_cols
+ return _get_full(query_context, query_obj, force_cached)
+
+
def _get_results(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
) -> Dict[str, Any]:
@@ -182,6 +203,7 @@ _result_type_functions: Dict[
# and post-process it later where we have the chart context, since
# post-processing is unique to each visualization type
ChartDataResultType.POST_PROCESSED: _get_full,
+ ChartDataResultType.DRILL_DETAIL: _get_drill_detail,
}
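The new handler plugs into the _result_type_functions dispatch table shown above. A self-contained sketch of that pattern with simplified stand-in handlers (hypothetical names, not Superset's actual API):

    from typing import Any, Callable, Dict

    def get_samples(query: Dict[str, Any]) -> Dict[str, Any]:
        # Historical behavior: the samples path drops the time filter.
        return {"type": "samples", "query": query}

    def get_drill_detail(query: Dict[str, Any]) -> Dict[str, Any]:
        # Drill detail keeps the full query, time filter included.
        return {"type": "drill_detail", "query": query}

    HANDLERS: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]] = {
        "samples": get_samples,
        "drill_detail": get_drill_detail,
    }

    # New result types slot in without touching the call sites.
    result = HANDLERS["drill_detail"]({"time_range": "2000-01-02 : 2000-01-04"})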
diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py
index 4c97f17e88..f9be7a7d4e 100644
--- a/superset/views/datasource/schemas.py
+++ b/superset/views/datasource/schemas.py
@@ -20,7 +20,7 @@ from marshmallow import fields, post_load, pre_load, Schema, validate
from typing_extensions import TypedDict
from superset import app
-from superset.charts.schemas import ChartDataFilterSchema
+from superset.charts.schemas import ChartDataExtrasSchema, ChartDataFilterSchema
from superset.utils.core import DatasourceType
@@ -62,6 +62,17 @@ class ExternalMetadataSchema(Schema):
class SamplesPayloadSchema(Schema):
filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
+ granularity = fields.String(
+ allow_none=True,
+ )
+ time_range = fields.String(
+ allow_none=True,
+ )
+ extras = fields.Nested(
+ ChartDataExtrasSchema,
+ description="Extra parameters to add to the query.",
+ allow_none=True,
+ )
@pre_load
# pylint: disable=no-self-use, unused-argument
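With these fields a samples payload may now carry a time grain column, a time range, and extra SQL alongside the existing filters. A minimal validation sketch assuming marshmallow is installed, with ChartDataExtrasSchema reduced to a stand-in and the filter items loosened to plain dicts for illustration:

    from marshmallow import Schema, fields

    class ExtrasSchema(Schema):  # stand-in for ChartDataExtrasSchema
        where = fields.String(allow_none=True)

    class SamplesPayloadSchema(Schema):
        filters = fields.List(fields.Dict(), required=False)
        granularity = fields.String(allow_none=True)
        time_range = fields.String(allow_none=True)
        extras = fields.Nested(ExtrasSchema, allow_none=True)

    payload = SamplesPayloadSchema().load({
        "granularity": "col5",
        "time_range": "2000-01-02 : 2000-01-04",
        "filters": [{"col": "col2", "op": "==", "val": "c"}],
        "extras": {"where": "col3 = 1.2"},
    })  # raises ValidationError on malformed input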
diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py
index 0191db2947..42cddf4167 100644
--- a/superset/views/datasource/utils.py
+++ b/superset/views/datasource/utils.py
@@ -60,17 +60,30 @@ def get_samples( # pylint: disable=too-many-arguments,too-many-locals
limit_clause = get_limit_clause(page, per_page)
# todo(yongjie): Constructing count(*) and samples in the same query_context,
- # then remove query_type==SAMPLES
- # constructing samples query
- samples_instance = QueryContextFactory().create(
- datasource={
- "type": datasource.type,
- "id": datasource.id,
- },
- queries=[{**payload, **limit_clause} if payload else limit_clause],
- result_type=ChartDataResultType.SAMPLES,
- force=force,
- )
+ if payload is None:
+ # constructing samples query
+ samples_instance = QueryContextFactory().create(
+ datasource={
+ "type": datasource.type,
+ "id": datasource.id,
+ },
+ queries=[limit_clause],
+ result_type=ChartDataResultType.SAMPLES,
+ force=force,
+ )
+ else:
+ # constructing drill detail query
+ # When query_type == 'samples', the time filter is removed,
+ # so SAMPLES is not suitable for a drill-detail query.
+ samples_instance = QueryContextFactory().create(
+ datasource={
+ "type": datasource.type,
+ "id": datasource.id,
+ },
+ queries=[{**payload, **limit_clause}],
+ result_type=ChartDataResultType.DRILL_DETAIL,
+ force=force,
+ )
# constructing count(*) query
count_star_metric = {
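Condensed, the branch above routes an empty payload to the legacy SAMPLES path and any filter payload, merged with the pagination clause, to DRILL_DETAIL so the time filter survives. A runnable sketch with hypothetical helper names:

    from typing import Any, Dict, Optional

    def build_query_context(payload: Optional[Dict[str, Any]],
                            limit_clause: Dict[str, Any]) -> Dict[str, Any]:
        if payload is None:
            # No filters: plain samples query, pagination only.
            return {"queries": [limit_clause], "result_type": "samples"}
        # Filters present: merge pagination into the payload and keep
        # the time filter by using the drill-detail result type.
        return {
            "queries": [{**payload, **limit_clause}],
            "result_type": "drill_detail",
        }

    print(build_query_context(None, {"row_limit": 10, "row_offset": 0}))
    print(build_query_context({"time_range": "2000-01-02 : 2000-01-04"},
                              {"row_limit": 10, "row_offset": 0}))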
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index 043d792219..549a987db1 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -314,7 +314,7 @@ def physical_dataset():
col2 VARCHAR(255),
col3 DECIMAL(4,2),
col4 VARCHAR(255),
- col5 VARCHAR(255)
+ col5 TIMESTAMP
);
"""
)
@@ -342,11 +342,10 @@ def physical_dataset():
TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
- TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
+ TableColumn(column_name="col5", type="TIMESTAMP", is_dttm=True, table=dataset)
SqlMetric(metric_name="count", expression="count(*)", table=dataset)
db.session.merge(dataset)
- if example_database.backend == "sqlite":
- db.session.commit()
+ db.session.commit()
yield dataset
@@ -355,5 +354,7 @@ def physical_dataset():
DROP TABLE physical_dataset;
"""
)
- db.session.delete(dataset)
+ dataset = db.session.query(SqlaTable).filter_by(table_name="physical_dataset").all()
+ for ds in dataset:
+ db.session.delete(ds)
db.session.commit()
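The teardown now deletes by a fresh query rather than a possibly stale instance, so repeated fixture runs cannot leave duplicate physical_dataset rows behind. A self-contained sketch of that pattern, using an in-memory SQLite database and a simplified SqlaTable model (both assumptions, for illustration):

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class SqlaTable(Base):
        __tablename__ = "tables"
        id = Column(Integer, primary_key=True)
        table_name = Column(String(250))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([SqlaTable(table_name="physical_dataset"),
                         SqlaTable(table_name="physical_dataset")])
        session.commit()
        # Query-then-delete removes every matching row, not just the
        # one instance a previous run happened to hold a reference to.
        for ds in session.query(SqlaTable).filter_by(
                table_name="physical_dataset").all():
            session.delete(ds)
        session.commit()
        assert session.query(SqlaTable).count() == 0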
diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py
index ad4d625cc5..ef3ba0c69d 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -432,14 +432,13 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
test_client.post(uri)
# get from cache
rv = test_client.post(uri)
- rv_data = json.loads(rv.data)
assert rv.status_code == 200
- assert len(rv_data["result"]["data"]) == 10
+ assert len(rv.json["result"]["data"]) == 10
assert QueryCacheManager.has(
- rv_data["result"]["cache_key"],
+ rv.json["result"]["cache_key"],
region=CacheRegion.DATA,
)
- assert rv_data["result"]["is_cached"]
+ assert rv.json["result"]["is_cached"]
# 2. should read through cache data
uri2 = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true"
@@ -447,19 +446,18 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
test_client.post(uri2)
# force query
rv2 = test_client.post(uri2)
- rv_data2 = json.loads(rv2.data)
assert rv2.status_code == 200
- assert len(rv_data2["result"]["data"]) == 10
+ assert len(rv2.json["result"]["data"]) == 10
assert QueryCacheManager.has(
- rv_data2["result"]["cache_key"],
+ rv2.json["result"]["cache_key"],
region=CacheRegion.DATA,
)
- assert not rv_data2["result"]["is_cached"]
+ assert not rv2.json["result"]["is_cached"]
# 3. data precision
- assert "colnames" in rv_data2["result"]
- assert "coltypes" in rv_data2["result"]
- assert "data" in rv_data2["result"]
+ assert "colnames" in rv2.json["result"]
+ assert "coltypes" in rv2.json["result"]
+ assert "data" in rv2.json["result"]
eager_samples = virtual_dataset.database.get_df(
f"select * from ({virtual_dataset.sql}) as tbl"
@@ -468,7 +466,7 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
# the col3 is Decimal
eager_samples["col3"] = eager_samples["col3"].apply(float)
eager_samples = eager_samples.to_dict(orient="records")
- assert eager_samples == rv_data2["result"]["data"]
+ assert eager_samples == rv2.json["result"]["data"]
def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
@@ -486,10 +484,9 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
rv = test_client.post(uri)
assert rv.status_code == 422
- rv_data = json.loads(rv.data)
- assert "error" in rv_data
+ assert "error" in rv.json
if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
- assert "INCORRECT SQL" in rv_data.get("error")
+ assert "INCORRECT SQL" in rv.json.get("error")
def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
@@ -498,11 +495,10 @@ def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_d
)
rv = test_client.post(uri)
assert rv.status_code == 200
- rv_data = json.loads(rv.data)
assert QueryCacheManager.has(
- rv_data["result"]["cache_key"], region=CacheRegion.DATA
+ rv.json["result"]["cache_key"], region=CacheRegion.DATA
)
- assert len(rv_data["result"]["data"]) == 10
+ assert len(rv.json["result"]["data"]) == 10
def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
@@ -533,9 +529,8 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
},
)
assert rv.status_code == 200
- rv_data = json.loads(rv.data)
- assert rv_data["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
- assert rv_data["result"]["rowcount"] == 1
+ assert rv.json["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
+ assert rv.json["result"]["rowcount"] == 1
# empty results
rv = test_client.post(
@@ -547,9 +542,64 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
},
)
assert rv.status_code == 200
- rv_data = json.loads(rv.data)
- assert rv_data["result"]["colnames"] == []
- assert rv_data["result"]["rowcount"] == 0
+ assert rv.json["result"]["colnames"] == []
+ assert rv.json["result"]["rowcount"] == 0
+
+
+def test_get_samples_with_time_filter(test_client, login_as_admin, physical_dataset):
+ uri = (
+ f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
+ )
+ payload = {
+ "granularity": "col5",
+ "time_range": "2000-01-02 : 2000-01-04",
+ }
+ rv = test_client.post(uri, json=payload)
+ assert len(rv.json["result"]["data"]) == 2
+ if physical_dataset.database.backend != "sqlite":
+ assert [row["col5"] for row in rv.json["result"]["data"]] == [
+ 946771200000.0, # 2000-01-02 00:00:00
+ 946857600000.0, # 2000-01-03 00:00:00
+ ]
+ assert rv.json["result"]["page"] == 1
+ assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
+ assert rv.json["result"]["total_count"] == 2
+
+
+def test_get_samples_with_multiple_filters(
+ test_client, login_as_admin, physical_dataset
+):
+ # 1. empty response
+ uri = (
+ f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
+ )
+ payload = {
+ "granularity": "col5",
+ "time_range": "2000-01-02 : 2000-01-04",
+ "filters": [
+ {"col": "col4", "op": "IS NOT NULL"},
+ ],
+ }
+ rv = test_client.post(uri, json=payload)
+ assert len(rv.json["result"]["data"]) == 0
+
+ # 2. adhoc filters, time filters, and custom where
+ payload = {
+ "granularity": "col5",
+ "time_range": "2000-01-02 : 2000-01-04",
+ "filters": [
+ {"col": "col2", "op": "==", "val": "c"},
+ ],
+ "extras": {"where": "col3 = 1.2 and col4 is null"},
+ }
+ rv = test_client.post(uri, json=payload)
+ assert len(rv.json["result"]["data"]) == 1
+ assert rv.json["result"]["total_count"] == 1
+ assert "2000-01-02" in rv.json["result"]["query"]
+ assert "2000-01-04" in rv.json["result"]["query"]
+ assert "col3 = 1.2" in rv.json["result"]["query"]
+ assert "col4 is null" in rv.json["result"]["query"]
+ assert "col2 = 'c'" in rv.json["result"]["query"]
def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
@@ -558,10 +608,9 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
)
rv = test_client.post(uri)
- rv_data = json.loads(rv.data)
- assert rv_data["result"]["page"] == 1
- assert rv_data["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
- assert rv_data["result"]["total_count"] == 10
+ assert rv.json["result"]["page"] == 1
+ assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
+ assert rv.json["result"]["total_count"] == 10
# 2. incorrect per_page
per_pages = (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx")
@@ -582,25 +631,22 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
# 4. turning pages
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=1"
rv = test_client.post(uri)
- rv_data = json.loads(rv.data)
- assert rv_data["result"]["page"] == 1
- assert rv_data["result"]["per_page"] == 2
- assert rv_data["result"]["total_count"] == 10
- assert [row["col1"] for row in rv_data["result"]["data"]] == [0, 1]
+ assert rv.json["result"]["page"] == 1
+ assert rv.json["result"]["per_page"] == 2
+ assert rv.json["result"]["total_count"] == 10
+ assert [row["col1"] for row in rv.json["result"]["data"]] == [0, 1]
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=2"
rv = test_client.post(uri)
- rv_data = json.loads(rv.data)
- assert rv_data["result"]["page"] == 2
- assert rv_data["result"]["per_page"] == 2
- assert rv_data["result"]["total_count"] == 10
- assert [row["col1"] for row in rv_data["result"]["data"]] == [2, 3]
+ assert rv.json["result"]["page"] == 2
+ assert rv.json["result"]["per_page"] == 2
+ assert rv.json["result"]["total_count"] == 10
+ assert [row["col1"] for row in rv.json["result"]["data"]] == [2, 3]
# 5. Exceeding the maximum pages
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=6"
rv = test_client.post(uri)
- rv_data = json.loads(rv.data)
- assert rv_data["result"]["page"] == 6
- assert rv_data["result"]["per_page"] == 2
- assert rv_data["result"]["total_count"] == 10
- assert [row["col1"] for row in rv_data["result"]["data"]] == []
+ assert rv.json["result"]["page"] == 6
+ assert rv.json["result"]["per_page"] == 2
+ assert rv.json["result"]["total_count"] == 10
+ assert [row["col1"] for row in rv.json["result"]["data"]] == []
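Taken together, the tests above exercise the new request shape end to end: adhoc filters, a time filter, and a custom WHERE clause in one samples call. A client-side sketch of the same request, assuming a running Superset instance and an already-authenticated requests session (base URL, auth, and dataset id are placeholders):

    import requests

    BASE_URL = "http://localhost:8088"  # placeholder
    session = requests.Session()        # assumed already authenticated

    payload = {
        "granularity": "col5",
        "time_range": "2000-01-02 : 2000-01-04",
        "filters": [{"col": "col2", "op": "==", "val": "c"}],
        "extras": {"where": "col3 = 1.2 and col4 is null"},
    }
    rv = session.post(
        f"{BASE_URL}/datasource/samples",
        params={"datasource_id": 1, "datasource_type": "table",
                "per_page": 10, "page": 1},
        json=payload,
    )
    print(rv.json()["result"]["data"])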