Posted to commits@superset.apache.org by yo...@apache.org on 2022/07/22 12:14:53 UTC
[superset] branch master updated: feat: the samples endpoint supports filters and pagination (#20683)
This is an automated email from the ASF dual-hosted git repository.
yongjiezhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new f011abae2b feat: the samples endpoint supports filters and pagination (#20683)
f011abae2b is described below
commit f011abae2b1bbcffc4eddb1a88872dea622693fb
Author: Yongjie Zhao <yo...@gmail.com>
AuthorDate: Fri Jul 22 20:14:42 2022 +0800
feat: the samples endpoint supports filters and pagination (#20683)
---
.../cypress/integration/explore/control.test.ts | 2 +-
.../src/components/Chart/chartAction.js | 5 +-
.../DataTablesPane/test/SamplesPane.test.tsx | 12 +-
superset/datasets/api.py | 70 +-------
superset/datasets/commands/samples.py | 80 ---------
superset/explore/api.py | 84 +--------
superset/explore/commands/samples.py | 93 ----------
superset/views/datasource/schemas.py | 32 +++-
superset/views/datasource/utils.py | 115 ++++++++++++
superset/views/datasource/views.py | 24 +++
tests/integration_tests/conftest.py | 104 +++++++++++
tests/integration_tests/datasets/api_tests.py | 101 -----------
tests/integration_tests/datasource_tests.py | 194 ++++++++++++++++++++-
13 files changed, 479 insertions(+), 437 deletions(-)
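
In short, this commit replaces the two GET sample endpoints (api/v1/dataset/<pk>/samples and api/v1/explore/samples) with a single POST /datasource/samples that takes the datasource identity and pagination in the query string and an optional filter payload in the JSON body. A minimal sketch of exercising the new endpoint, assuming a local instance at localhost:8088 and an already-authenticated requests.Session (the host, port, auth setup, and datasource_id of 34 are illustrative, not part of the commit):

    import requests

    session = requests.Session()  # assumed to already carry login/CSRF state

    params = {
        "datasource_type": "table",  # must be a valid DatasourceType value
        "datasource_id": 34,         # hypothetical dataset id
        "force": "false",            # "true" bypasses the query cache
        "page": 2,
        "per_page": 100,             # capped at SAMPLES_ROW_LIMIT
    }
    payload = {"filters": [{"col": "col2", "op": "==", "val": "a"}]}

    resp = session.post(
        "http://localhost:8088/datasource/samples",
        params=params,
        json=payload,
    )
    result = resp.json()["result"]
    # Alongside "data", "colnames", and "coltypes", the result now carries
    # the pagination fields "page", "per_page", and "total_count".
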
diff --git a/superset-frontend/cypress-base/cypress/integration/explore/control.test.ts b/superset-frontend/cypress-base/cypress/integration/explore/control.test.ts
index f1adec9e44..a4b85de4de 100644
--- a/superset-frontend/cypress-base/cypress/integration/explore/control.test.ts
+++ b/superset-frontend/cypress-base/cypress/integration/explore/control.test.ts
@@ -129,7 +129,7 @@ describe('Test datatable', () => {
});
it('Datapane loads view samples', () => {
cy.intercept(
- 'api/v1/explore/samples?force=false&datasource_type=table&datasource_id=*',
+ 'datasource/samples?force=false&datasource_type=table&datasource_id=*',
).as('Samples');
cy.contains('Samples')
.click()
diff --git a/superset-frontend/src/components/Chart/chartAction.js b/superset-frontend/src/components/Chart/chartAction.js
index 139d91cd1d..044593eb37 100644
--- a/superset-frontend/src/components/Chart/chartAction.js
+++ b/superset-frontend/src/components/Chart/chartAction.js
@@ -602,10 +602,11 @@ export const getDatasourceSamples = async (
datasourceType,
datasourceId,
force,
+ jsonPayload,
) => {
- const endpoint = `/api/v1/explore/samples?force=${force}&datasource_type=${datasourceType}&datasource_id=${datasourceId}`;
+ const endpoint = `/datasource/samples?force=${force}&datasource_type=${datasourceType}&datasource_id=${datasourceId}`;
try {
- const response = await SupersetClient.get({ endpoint });
+ const response = await SupersetClient.post({ endpoint, jsonPayload });
return response.json.result;
} catch (err) {
const clientError = await getClientErrorObject(err);
diff --git a/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx b/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx
index 0aa0b03a06..391540f4d8 100644
--- a/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx
+++ b/superset-frontend/src/explore/components/DataTablesPane/test/SamplesPane.test.tsx
@@ -29,8 +29,8 @@ import { SamplesPane } from '../components';
import { createSamplesPaneProps } from './fixture';
describe('SamplesPane', () => {
- fetchMock.get(
- 'end:/api/v1/explore/samples?force=false&datasource_type=table&datasource_id=34',
+ fetchMock.post(
+ 'end:/datasource/samples?force=false&datasource_type=table&datasource_id=34',
{
result: {
data: [],
@@ -40,8 +40,8 @@ describe('SamplesPane', () => {
},
);
- fetchMock.get(
- 'end:/api/v1/explore/samples?force=true&datasource_type=table&datasource_id=35',
+ fetchMock.post(
+ 'end:/datasource/samples?force=true&datasource_type=table&datasource_id=35',
{
result: {
data: [
@@ -54,8 +54,8 @@ describe('SamplesPane', () => {
},
);
- fetchMock.get(
- 'end:/api/v1/explore/samples?force=false&datasource_type=table&datasource_id=36',
+ fetchMock.post(
+ 'end:/datasource/samples?force=false&datasource_type=table&datasource_id=36',
400,
);
diff --git a/superset/datasets/api.py b/superset/datasets/api.py
index f6890655ed..e25e8252f9 100644
--- a/superset/datasets/api.py
+++ b/superset/datasets/api.py
@@ -21,9 +21,8 @@ from io import BytesIO
from typing import Any
from zipfile import is_zipfile, ZipFile
-import simplejson
import yaml
-from flask import make_response, request, Response, send_file
+from flask import request, Response, send_file
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext
@@ -46,13 +45,11 @@ from superset.datasets.commands.exceptions import (
DatasetInvalidError,
DatasetNotFoundError,
DatasetRefreshFailedError,
- DatasetSamplesFailedError,
DatasetUpdateFailedError,
)
from superset.datasets.commands.export import ExportDatasetsCommand
from superset.datasets.commands.importers.dispatcher import ImportDatasetsCommand
from superset.datasets.commands.refresh import RefreshDatasetCommand
-from superset.datasets.commands.samples import SamplesDatasetCommand
from superset.datasets.commands.update import UpdateDatasetCommand
from superset.datasets.dao import DatasetDAO
from superset.datasets.filters import DatasetCertifiedFilter, DatasetIsNullOrEmptyFilter
@@ -63,7 +60,7 @@ from superset.datasets.schemas import (
get_delete_ids_schema,
get_export_ids_schema,
)
-from superset.utils.core import json_int_dttm_ser, parse_boolean_string
+from superset.utils.core import parse_boolean_string
from superset.views.base import DatasourceFilter, generate_download_headers
from superset.views.base_api import (
BaseSupersetModelRestApi,
@@ -93,7 +90,6 @@ class DatasetRestApi(BaseSupersetModelRestApi):
"bulk_delete",
"refresh",
"related_objects",
- "samples",
}
list_columns = [
"id",
@@ -775,65 +771,3 @@ class DatasetRestApi(BaseSupersetModelRestApi):
)
command.run()
return self.response(200, message="OK")
-
- @expose("/<pk>/samples")
- @protect()
- @safe
- @statsd_metrics
- @event_logger.log_this_with_context(
- action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.samples",
- log_to_statsd=False,
- )
- def samples(self, pk: int) -> Response:
- """get samples from a Dataset
- ---
- get:
- description: >-
- get samples from a Dataset
- parameters:
- - in: path
- schema:
- type: integer
- name: pk
- - in: query
- schema:
- type: boolean
- name: force
- responses:
- 200:
- description: Dataset samples
- content:
- application/json:
- schema:
- type: object
- properties:
- result:
- $ref: '#/components/schemas/ChartDataResponseResult'
- 401:
- $ref: '#/components/responses/401'
- 403:
- $ref: '#/components/responses/403'
- 404:
- $ref: '#/components/responses/404'
- 422:
- $ref: '#/components/responses/422'
- 500:
- $ref: '#/components/responses/500'
- """
- try:
- force = parse_boolean_string(request.args.get("force"))
- rv = SamplesDatasetCommand(pk, force).run()
- response_data = simplejson.dumps(
- {"result": rv},
- default=json_int_dttm_ser,
- ignore_nan=True,
- )
- resp = make_response(response_data, 200)
- resp.headers["Content-Type"] = "application/json; charset=utf-8"
- return resp
- except DatasetNotFoundError:
- return self.response_404()
- except DatasetForbiddenError:
- return self.response_403()
- except DatasetSamplesFailedError as ex:
- return self.response_400(message=str(ex))
diff --git a/superset/datasets/commands/samples.py b/superset/datasets/commands/samples.py
deleted file mode 100644
index e252cfb62f..0000000000
--- a/superset/datasets/commands/samples.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import logging
-from typing import Any, Dict, Optional
-
-from superset import security_manager
-from superset.commands.base import BaseCommand
-from superset.common.chart_data import ChartDataResultType
-from superset.common.query_context_factory import QueryContextFactory
-from superset.common.utils.query_cache_manager import QueryCacheManager
-from superset.connectors.sqla.models import SqlaTable
-from superset.constants import CacheRegion
-from superset.datasets.commands.exceptions import (
- DatasetForbiddenError,
- DatasetNotFoundError,
- DatasetSamplesFailedError,
-)
-from superset.datasets.dao import DatasetDAO
-from superset.exceptions import SupersetSecurityException
-from superset.utils.core import QueryStatus
-
-logger = logging.getLogger(__name__)
-
-
-class SamplesDatasetCommand(BaseCommand):
- def __init__(self, model_id: int, force: bool):
- self._model_id = model_id
- self._force = force
- self._model: Optional[SqlaTable] = None
-
- def run(self) -> Dict[str, Any]:
- self.validate()
- if not self._model:
- raise DatasetNotFoundError()
-
- qc_instance = QueryContextFactory().create(
- datasource={
- "type": self._model.type,
- "id": self._model.id,
- },
- queries=[{}],
- result_type=ChartDataResultType.SAMPLES,
- force=self._force,
- )
- results = qc_instance.get_payload()
- try:
- sample_data = results["queries"][0]
- error_msg = sample_data.get("error")
- if sample_data.get("status") == QueryStatus.FAILED and error_msg:
- cache_key = sample_data.get("cache_key")
- QueryCacheManager.delete(cache_key, region=CacheRegion.DATA)
- raise DatasetSamplesFailedError(error_msg)
- return sample_data
- except (IndexError, KeyError) as exc:
- raise DatasetSamplesFailedError from exc
-
- def validate(self) -> None:
- # Validate/populate model exists
- self._model = DatasetDAO.find_by_id(self._model_id)
- if not self._model:
- raise DatasetNotFoundError()
- # Check ownership
- try:
- security_manager.raise_for_ownership(self._model)
- except SupersetSecurityException as ex:
- raise DatasetForbiddenError() from ex
diff --git a/superset/explore/api.py b/superset/explore/api.py
index 237eb67dbb..7cce592d36 100644
--- a/superset/explore/api.py
+++ b/superset/explore/api.py
@@ -16,22 +16,14 @@
# under the License.
import logging
-import simplejson
-from flask import g, make_response, request, Response
+from flask import g, request, Response
from flask_appbuilder.api import BaseApi, expose, protect, safe
from superset.charts.commands.exceptions import ChartNotFoundError
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
-from superset.dao.exceptions import DatasourceNotFound
from superset.explore.commands.get import GetExploreCommand
from superset.explore.commands.parameters import CommandParameters
-from superset.explore.commands.samples import SamplesDatasourceCommand
-from superset.explore.exceptions import (
- DatasetAccessDeniedError,
- DatasourceForbiddenError,
- DatasourceSamplesFailedError,
- WrongEndpointError,
-)
+from superset.explore.exceptions import DatasetAccessDeniedError, WrongEndpointError
from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
from superset.explore.schemas import ExploreContextSchema
from superset.extensions import event_logger
@@ -39,16 +31,13 @@ from superset.temporary_cache.commands.exceptions import (
TemporaryCacheAccessDeniedError,
TemporaryCacheResourceNotFoundError,
)
-from superset.utils.core import json_int_dttm_ser, parse_boolean_string
logger = logging.getLogger(__name__)
class ExploreRestApi(BaseApi):
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
- include_route_methods = {RouteMethod.GET} | {
- "samples",
- }
+ include_route_methods = {RouteMethod.GET}
allow_browser_login = True
class_permission_name = "Explore"
resource_name = "explore"
@@ -146,70 +135,3 @@ class ExploreRestApi(BaseApi):
return self.response(403, message=str(ex))
except TemporaryCacheResourceNotFoundError as ex:
return self.response(404, message=str(ex))
-
- @expose("/samples", methods=["GET"])
- @protect()
- @safe
- @event_logger.log_this_with_context(
- action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.samples",
- log_to_statsd=False,
- )
- def samples(self) -> Response:
- """get samples from a Datasource
- ---
- get:
- description: >-
- get samples from a Datasource
- parameters:
- - in: path
- schema:
- type: integer
- name: pk
- - in: query
- schema:
- type: boolean
- name: force
- responses:
- 200:
- description: Datasource samples
- content:
- application/json:
- schema:
- type: object
- properties:
- result:
- $ref: '#/components/schemas/ChartDataResponseResult'
- 401:
- $ref: '#/components/responses/401'
- 403:
- $ref: '#/components/responses/403'
- 404:
- $ref: '#/components/responses/404'
- 422:
- $ref: '#/components/responses/422'
- 500:
- $ref: '#/components/responses/500'
- """
- try:
- force = parse_boolean_string(request.args.get("force"))
- rv = SamplesDatasourceCommand(
- user=g.user,
- datasource_type=request.args.get("datasource_type", type=str),
- datasource_id=request.args.get("datasource_id", type=int),
- force=force,
- ).run()
-
- response_data = simplejson.dumps(
- {"result": rv},
- default=json_int_dttm_ser,
- ignore_nan=True,
- )
- resp = make_response(response_data, 200)
- resp.headers["Content-Type"] = "application/json; charset=utf-8"
- return resp
- except DatasourceNotFound:
- return self.response_404()
- except DatasourceForbiddenError:
- return self.response_403()
- except DatasourceSamplesFailedError as ex:
- return self.response_400(message=str(ex))
diff --git a/superset/explore/commands/samples.py b/superset/explore/commands/samples.py
deleted file mode 100644
index 7fda5c1bc1..0000000000
--- a/superset/explore/commands/samples.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import logging
-from typing import Any, Dict, Optional
-
-from flask_appbuilder.security.sqla.models import User
-
-from superset import db, security_manager
-from superset.commands.base import BaseCommand
-from superset.common.chart_data import ChartDataResultType
-from superset.common.query_context_factory import QueryContextFactory
-from superset.common.utils.query_cache_manager import QueryCacheManager
-from superset.constants import CacheRegion
-from superset.dao.exceptions import DatasourceNotFound
-from superset.datasource.dao import Datasource, DatasourceDAO
-from superset.exceptions import SupersetSecurityException
-from superset.explore.exceptions import (
- DatasourceForbiddenError,
- DatasourceSamplesFailedError,
-)
-from superset.utils.core import DatasourceType, QueryStatus
-
-logger = logging.getLogger(__name__)
-
-
-class SamplesDatasourceCommand(BaseCommand):
- def __init__(
- self,
- user: User,
- datasource_id: Optional[int],
- datasource_type: Optional[str],
- force: bool,
- ):
- self._actor = user
- self._datasource_id = datasource_id
- self._datasource_type = datasource_type
- self._force = force
- self._model: Optional[Datasource] = None
-
- def run(self) -> Dict[str, Any]:
- self.validate()
- if not self._model:
- raise DatasourceNotFound()
-
- qc_instance = QueryContextFactory().create(
- datasource={
- "type": self._model.type,
- "id": self._model.id,
- },
- queries=[{}],
- result_type=ChartDataResultType.SAMPLES,
- force=self._force,
- )
- results = qc_instance.get_payload()
- try:
- sample_data = results["queries"][0]
- error_msg = sample_data.get("error")
- if sample_data.get("status") == QueryStatus.FAILED and error_msg:
- cache_key = sample_data.get("cache_key")
- QueryCacheManager.delete(cache_key, region=CacheRegion.DATA)
- raise DatasourceSamplesFailedError(error_msg)
- return sample_data
- except (IndexError, KeyError) as exc:
- raise DatasourceSamplesFailedError from exc
-
- def validate(self) -> None:
- # Validate/populate model exists
- if self._datasource_type and self._datasource_id:
- self._model = DatasourceDAO.get_datasource(
- session=db.session,
- datasource_type=DatasourceType(self._datasource_type),
- datasource_id=self._datasource_id,
- )
-
- # Check ownership
- try:
- security_manager.raise_for_ownership(self._model)
- except SupersetSecurityException as ex:
- raise DatasourceForbiddenError() from ex
diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py
index 64b2b854bb..4c97f17e88 100644
--- a/superset/views/datasource/schemas.py
+++ b/superset/views/datasource/schemas.py
@@ -14,11 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any
+from typing import Any, Dict
-from marshmallow import fields, post_load, Schema
+from marshmallow import fields, post_load, pre_load, Schema, validate
from typing_extensions import TypedDict
+from superset import app
+from superset.charts.schemas import ChartDataFilterSchema
+from superset.utils.core import DatasourceType
+
class ExternalMetadataParams(TypedDict):
datasource_type: str
@@ -54,3 +58,27 @@ class ExternalMetadataSchema(Schema):
schema_name=data.get("schema_name", ""),
table_name=data["table_name"],
)
+
+
+class SamplesPayloadSchema(Schema):
+ filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
+
+ @pre_load
+ # pylint: disable=no-self-use, unused-argument
+ def handle_none(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+ if data is None:
+ return {}
+ return data
+
+
+class SamplesRequestSchema(Schema):
+ datasource_type = fields.String(
+ validate=validate.OneOf([e.value for e in DatasourceType]), required=True
+ )
+ datasource_id = fields.Integer(required=True)
+ force = fields.Boolean(load_default=False)
+ page = fields.Integer(load_default=1)
+ per_page = fields.Integer(
+ validate=validate.Range(min=1, max=app.config.get("SAMPLES_ROW_LIMIT", 1000)),
+ load_default=app.config.get("SAMPLES_ROW_LIMIT", 1000),
+ )
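
The two schemas split validation between the query string and the JSON body: SamplesRequestSchema requires a valid datasource_type/datasource_id pair and supplies defaults for force, page, and per_page (with per_page range-checked against SAMPLES_ROW_LIMIT), while SamplesPayloadSchema accepts an optional list of chart-data filters and, via its pre_load hook, treats a missing body as an empty payload. A sketch of that behavior with illustrative values, assuming the default SAMPLES_ROW_LIMIT of 1000 and an initialized app context:

    from marshmallow import ValidationError

    from superset.views.datasource.schemas import (
        SamplesPayloadSchema,
        SamplesRequestSchema,
    )

    params = SamplesRequestSchema().load(
        {"datasource_type": "table", "datasource_id": "34"}
    )
    assert params["force"] is False    # load_default kicks in
    assert params["page"] == 1
    assert params["per_page"] == 1000  # assumed SAMPLES_ROW_LIMIT

    try:
        SamplesRequestSchema().load({"datasource_type": "not-a-type"})
    except ValidationError as err:
        # datasource_type fails the OneOf check; datasource_id is missing
        print(err.messages)

    # A request without a JSON body is coerced to an empty payload.
    assert SamplesPayloadSchema().load(None) == {}
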
diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py
new file mode 100644
index 0000000000..0191db2947
--- /dev/null
+++ b/superset/views/datasource/utils.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Any, Dict, Optional
+
+from superset import app, db
+from superset.common.chart_data import ChartDataResultType
+from superset.common.query_context_factory import QueryContextFactory
+from superset.common.utils.query_cache_manager import QueryCacheManager
+from superset.constants import CacheRegion
+from superset.datasets.commands.exceptions import DatasetSamplesFailedError
+from superset.datasource.dao import DatasourceDAO
+from superset.utils.core import QueryStatus
+from superset.views.datasource.schemas import SamplesPayloadSchema
+
+
+def get_limit_clause(page: Optional[int], per_page: Optional[int]) -> Dict[str, int]:
+ samples_row_limit = app.config.get("SAMPLES_ROW_LIMIT", 1000)
+ limit = samples_row_limit
+ offset = 0
+
+ if isinstance(page, int) and isinstance(per_page, int):
+ limit = int(per_page)
+ if limit < 0 or limit > samples_row_limit:
+ # reset limit value if input is invalid
+ limit = samples_row_limit
+
+ offset = max((int(page) - 1) * limit, 0)
+
+ return {"row_offset": offset, "row_limit": limit}
+
+
+def get_samples( # pylint: disable=too-many-arguments,too-many-locals
+ datasource_type: str,
+ datasource_id: int,
+ force: bool = False,
+ page: int = 1,
+ per_page: int = 1000,
+ payload: Optional[SamplesPayloadSchema] = None,
+) -> Dict[str, Any]:
+ datasource = DatasourceDAO.get_datasource(
+ session=db.session,
+ datasource_type=datasource_type,
+ datasource_id=datasource_id,
+ )
+
+ limit_clause = get_limit_clause(page, per_page)
+
+ # TODO(yongjie): construct count(*) and samples in the same query_context,
+ # then remove query_type == SAMPLES
+ # constructing samples query
+ samples_instance = QueryContextFactory().create(
+ datasource={
+ "type": datasource.type,
+ "id": datasource.id,
+ },
+ queries=[{**payload, **limit_clause} if payload else limit_clause],
+ result_type=ChartDataResultType.SAMPLES,
+ force=force,
+ )
+
+ # constructing count(*) query
+ count_star_metric = {
+ "metrics": [
+ {
+ "expressionType": "SQL",
+ "sqlExpression": "COUNT(*)",
+ "label": "COUNT(*)",
+ }
+ ]
+ }
+ count_star_instance = QueryContextFactory().create(
+ datasource={
+ "type": datasource.type,
+ "id": datasource.id,
+ },
+ queries=[{**payload, **count_star_metric} if payload else count_star_metric],
+ result_type=ChartDataResultType.FULL,
+ force=force,
+ )
+ samples_results = samples_instance.get_payload()
+ count_star_results = count_star_instance.get_payload()
+
+ try:
+ sample_data = samples_results["queries"][0]
+ count_star_data = count_star_results["queries"][0]
+ failed_status = (
+ sample_data.get("status") == QueryStatus.FAILED
+ or count_star_data.get("status") == QueryStatus.FAILED
+ )
+ error_msg = sample_data.get("error") or count_star_data.get("error")
+ if failed_status and error_msg:
+ cache_key = sample_data.get("cache_key")
+ QueryCacheManager.delete(cache_key, region=CacheRegion.DATA)
+ raise DatasetSamplesFailedError(error_msg)
+
+ sample_data["page"] = page
+ sample_data["per_page"] = per_page
+ sample_data["total_count"] = count_star_data["data"][0]["COUNT(*)"]
+ return sample_data
+ except (IndexError, KeyError) as exc:
+ raise DatasetSamplesFailedError from exc
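
get_limit_clause is the heart of the pagination: it turns (page, per_page) into a row_offset/row_limit pair, clamping per_page to the configured SAMPLES_ROW_LIMIT and flooring the offset at zero; get_samples then runs two query contexts, a SAMPLES query carrying that limit clause and a COUNT(*) query for the total, and merges page, per_page, and total_count into the samples payload. A standalone restatement of the arithmetic for illustration, assuming the default SAMPLES_ROW_LIMIT of 1000:

    SAMPLES_ROW_LIMIT = 1000  # assumed default; configurable per deployment

    def limit_clause(page, per_page):
        # Mirrors get_limit_clause above: clamp the limit, floor the offset.
        limit, offset = SAMPLES_ROW_LIMIT, 0
        if isinstance(page, int) and isinstance(per_page, int):
            limit = per_page
            if limit < 0 or limit > SAMPLES_ROW_LIMIT:
                limit = SAMPLES_ROW_LIMIT  # reset invalid input
            offset = max((page - 1) * limit, 0)
        return {"row_offset": offset, "row_limit": limit}

    assert limit_clause(1, 2) == {"row_offset": 0, "row_limit": 2}
    assert limit_clause(2, 2) == {"row_offset": 2, "row_limit": 2}
    # A page past the end yields a valid but empty window:
    assert limit_clause(6, 2) == {"row_offset": 10, "row_limit": 2}
    # An out-of-range per_page falls back to the ceiling:
    assert limit_clause(1, 5000) == {"row_offset": 0, "row_limit": 1000}
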
diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py
index 4e43068c6f..60ee4baddc 100644
--- a/superset/views/datasource/views.py
+++ b/superset/views/datasource/views.py
@@ -50,7 +50,10 @@ from superset.views.datasource.schemas import (
ExternalMetadataParams,
ExternalMetadataSchema,
get_external_metadata_schema,
+ SamplesPayloadSchema,
+ SamplesRequestSchema,
)
+from superset.views.datasource.utils import get_samples
from superset.views.utils import sanitize_datasource_data
@@ -179,3 +182,24 @@ class Datasource(BaseSupersetView):
except (NoResultFound, NoSuchTableError) as ex:
raise DatasetNotFoundError() from ex
return self.json_response(external_metadata)
+
+ @expose("/samples", methods=["POST"])
+ @has_access_api
+ @api
+ @handle_api_exception
+ def samples(self) -> FlaskResponse:
+ try:
+ params = SamplesRequestSchema().load(request.args)
+ payload = SamplesPayloadSchema().load(request.json)
+ except ValidationError as err:
+ return json_error_response(err.messages, status=400)
+
+ rv = get_samples(
+ datasource_type=params["datasource_type"],
+ datasource_id=params["datasource_id"],
+ force=params["force"],
+ page=params["page"],
+ per_page=params["per_page"],
+ payload=payload,
+ )
+ return self.json_response({"result": rv})
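
The view ties the pieces together: query-string parameters go through SamplesRequestSchema, the JSON body through SamplesPayloadSchema, a ValidationError from either becomes a 400, and the merged payload from get_samples is returned under "result". Based on the fields asserted in the tests below, a successful response body looks roughly like this (the values are made up for illustration):

    # Rough shape of a successful response (illustrative values only):
    {
        "result": {
            "data": [{"col1": 2, "col2": "c"}, {"col1": 3, "col2": "d"}],
            "colnames": ["col1", "col2"],
            "coltypes": [0, 1],  # generic type codes
            "rowcount": 2,
            "cache_key": "<cache key>",
            "is_cached": False,
            "page": 2,
            "per_page": 2,
            "total_count": 10,
        }
    }
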
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index ea46039d84..6675509d68 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -206,3 +206,107 @@ def with_feature_flags(**mock_feature_flags):
return functools.update_wrapper(wrapper, test_fn)
return decorate
+
+
+@pytest.fixture
+def virtual_dataset():
+ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+
+ dataset = SqlaTable(
+ table_name="virtual_dataset",
+ sql=(
+ "SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5 "
+ "UNION ALL "
+ "SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00' "
+ "UNION ALL "
+ "SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00' "
+ "UNION ALL "
+ "SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00' "
+ "UNION ALL "
+ "SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00' "
+ "UNION ALL "
+ "SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00' "
+ "UNION ALL "
+ "SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00' "
+ "UNION ALL "
+ "SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00' "
+ "UNION ALL "
+ "SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00' "
+ "UNION ALL "
+ "SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00' "
+ ),
+ database=get_example_database(),
+ )
+ TableColumn(column_name="col1", type="INTEGER", table=dataset)
+ TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
+ TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
+ TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
+ # Datetime types are not consistent across database dialects, so temporarily use VARCHAR
+ TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
+
+ SqlMetric(metric_name="count", expression="count(*)", table=dataset)
+ db.session.merge(dataset)
+
+ yield dataset
+
+ db.session.delete(dataset)
+ db.session.commit()
+
+
+@pytest.fixture
+def physical_dataset():
+ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+
+ example_database = get_example_database()
+ engine = example_database.get_sqla_engine()
+ # sqlite can only execute one statement at a time
+ engine.execute(
+ """
+ CREATE TABLE IF NOT EXISTS physical_dataset(
+ col1 INTEGER,
+ col2 VARCHAR(255),
+ col3 DECIMAL(4,2),
+ col4 VARCHAR(255),
+ col5 VARCHAR(255)
+ );
+ """
+ )
+ engine.execute(
+ """
+ INSERT INTO physical_dataset values
+ (0, 'a', 1.0, NULL, '2000-01-01 00:00:00'),
+ (1, 'b', 1.1, NULL, '2000-01-02 00:00:00'),
+ (2, 'c', 1.2, NULL, '2000-01-03 00:00:00'),
+ (3, 'd', 1.3, NULL, '2000-01-04 00:00:00'),
+ (4, 'e', 1.4, NULL, '2000-01-05 00:00:00'),
+ (5, 'f', 1.5, NULL, '2000-01-06 00:00:00'),
+ (6, 'g', 1.6, NULL, '2000-01-07 00:00:00'),
+ (7, 'h', 1.7, NULL, '2000-01-08 00:00:00'),
+ (8, 'i', 1.8, NULL, '2000-01-09 00:00:00'),
+ (9, 'j', 1.9, NULL, '2000-01-10 00:00:00');
+ """
+ )
+
+ dataset = SqlaTable(
+ table_name="physical_dataset",
+ database=example_database,
+ )
+ TableColumn(column_name="col1", type="INTEGER", table=dataset)
+ TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
+ TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
+ TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
+ TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
+ SqlMetric(metric_name="count", expression="count(*)", table=dataset)
+ db.session.merge(dataset)
+ if example_database.backend == "sqlite":
+ db.session.commit()
+
+ yield dataset
+
+ engine.execute(
+ """
+ DROP TABLE physical_dataset;
+ """
+ )
+ db.session.delete(dataset)
+ db.session.commit()
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index d8e756e98e..46739f9631 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -27,9 +27,7 @@ import pytest
import yaml
from sqlalchemy.sql import func
-from superset.common.utils.query_cache_manager import QueryCacheManager
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
-from superset.constants import CacheRegion
from superset.dao.exceptions import (
DAOCreateFailedError,
DAODeleteFailedError,
@@ -2085,102 +2083,3 @@ class TestDatasetApi(SupersetTestCase):
db.session.delete(table_w_certification)
db.session.commit()
-
- @pytest.mark.usefixtures("create_datasets")
- def test_get_dataset_samples(self):
- """
- Dataset API: Test get dataset samples
- """
- if backend() == "sqlite":
- return
-
- dataset = self.get_fixture_datasets()[0]
-
- self.login(username="admin")
- uri = f"api/v1/dataset/{dataset.id}/samples"
-
- # 1. should cache data
- # feeds data
- self.client.get(uri)
- # get from cache
- rv = self.client.get(uri)
- rv_data = json.loads(rv.data)
- assert rv.status_code == 200
- assert "result" in rv_data
- assert rv_data["result"]["cached_dttm"] is not None
- cache_key1 = rv_data["result"]["cache_key"]
- assert QueryCacheManager.has(cache_key1, region=CacheRegion.DATA)
-
- # 2. should through cache
- uri2 = f"api/v1/dataset/{dataset.id}/samples?force=true"
- # feeds data
- self.client.get(uri2)
- # force query
- rv2 = self.client.get(uri2)
- rv_data2 = json.loads(rv2.data)
- assert rv_data2["result"]["cached_dttm"] is None
- cache_key2 = rv_data2["result"]["cache_key"]
- assert QueryCacheManager.has(cache_key2, region=CacheRegion.DATA)
-
- # 3. data precision
- assert "colnames" in rv_data2["result"]
- assert "coltypes" in rv_data2["result"]
- assert "data" in rv_data2["result"]
-
- eager_samples = dataset.database.get_df(
- f"select * from {dataset.table_name}"
- f' limit {self.app.config["SAMPLES_ROW_LIMIT"]}'
- ).to_dict(orient="records")
- assert eager_samples == rv_data2["result"]["data"]
-
- @pytest.mark.usefixtures("create_datasets")
- def test_get_dataset_samples_with_failed_cc(self):
- if backend() == "sqlite":
- return
-
- dataset = self.get_fixture_datasets()[0]
-
- self.login(username="admin")
- failed_column = TableColumn(
- column_name="DUMMY CC",
- type="VARCHAR(255)",
- table=dataset,
- expression="INCORRECT SQL",
- )
- uri = f"api/v1/dataset/{dataset.id}/samples"
- dataset.columns.append(failed_column)
- rv = self.client.get(uri)
- assert rv.status_code == 400
- rv_data = json.loads(rv.data)
- assert "message" in rv_data
- if dataset.database.db_engine_spec.engine_name == "PostgreSQL":
- assert "INCORRECT SQL" in rv_data.get("message")
-
- def test_get_dataset_samples_on_virtual_dataset(self):
- if backend() == "sqlite":
- return
-
- virtual_dataset = SqlaTable(
- table_name="virtual_dataset",
- sql=("SELECT 'foo' as foo, 'bar' as bar"),
- database=get_example_database(),
- )
- TableColumn(column_name="foo", type="VARCHAR(255)", table=virtual_dataset)
- TableColumn(column_name="bar", type="VARCHAR(255)", table=virtual_dataset)
- SqlMetric(metric_name="count", expression="count(*)", table=virtual_dataset)
-
- self.login(username="admin")
- uri = f"api/v1/dataset/{virtual_dataset.id}/samples"
- rv = self.client.get(uri)
- assert rv.status_code == 200
- rv_data = json.loads(rv.data)
- cache_key = rv_data["result"]["cache_key"]
- assert QueryCacheManager.has(cache_key, region=CacheRegion.DATA)
-
- # remove original column in dataset
- virtual_dataset.sql = "SELECT 'foo' as foo"
- rv = self.client.get(uri)
- assert rv.status_code == 400
-
- db.session.delete(virtual_dataset)
- db.session.commit()
diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py
index 8e4d269b20..ad4d625cc5 100644
--- a/tests/integration_tests/datasource_tests.py
+++ b/tests/integration_tests/datasource_tests.py
@@ -23,13 +23,15 @@ import prison
import pytest
from superset import app, db
-from superset.connectors.sqla.models import SqlaTable
+from superset.common.utils.query_cache_manager import QueryCacheManager
+from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+from superset.constants import CacheRegion
from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.exceptions import SupersetGenericDBErrorException
from superset.models.core import Database
-from superset.utils.core import DatasourceType, get_example_default_schema
-from superset.utils.database import get_example_database
+from superset.utils.core import backend, get_example_default_schema
+from superset.utils.database import get_example_database, get_main_database
from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@@ -416,3 +418,189 @@ class TestDatasource(SupersetTestCase):
self.login(username="admin")
resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False)
self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType")
+
+
+def test_get_samples(test_client, login_as_admin, virtual_dataset):
+ """
+ Dataset API: Test get dataset samples
+ """
+ # 1. should cache data
+ uri = (
+ f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
+ )
+ # feeds data
+ test_client.post(uri)
+ # get from cache
+ rv = test_client.post(uri)
+ rv_data = json.loads(rv.data)
+ assert rv.status_code == 200
+ assert len(rv_data["result"]["data"]) == 10
+ assert QueryCacheManager.has(
+ rv_data["result"]["cache_key"],
+ region=CacheRegion.DATA,
+ )
+ assert rv_data["result"]["is_cached"]
+
+ # 2. force=true should bypass the cache and re-query
+ uri2 = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true"
+ # feeds data
+ test_client.post(uri2)
+ # force query
+ rv2 = test_client.post(uri2)
+ rv_data2 = json.loads(rv2.data)
+ assert rv2.status_code == 200
+ assert len(rv_data2["result"]["data"]) == 10
+ assert QueryCacheManager.has(
+ rv_data2["result"]["cache_key"],
+ region=CacheRegion.DATA,
+ )
+ assert not rv_data2["result"]["is_cached"]
+
+ # 3. data precision
+ assert "colnames" in rv_data2["result"]
+ assert "coltypes" in rv_data2["result"]
+ assert "data" in rv_data2["result"]
+
+ eager_samples = virtual_dataset.database.get_df(
+ f"select * from ({virtual_dataset.sql}) as tbl"
+ f' limit {app.config["SAMPLES_ROW_LIMIT"]}'
+ )
+ # the col3 is Decimal
+ eager_samples["col3"] = eager_samples["col3"].apply(float)
+ eager_samples = eager_samples.to_dict(orient="records")
+ assert eager_samples == rv_data2["result"]["data"]
+
+
+def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
+ TableColumn(
+ column_name="DUMMY CC",
+ type="VARCHAR(255)",
+ table=virtual_dataset,
+ expression="INCORRECT SQL",
+ )
+ db.session.merge(virtual_dataset)
+
+ uri = (
+ f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
+ )
+ rv = test_client.post(uri)
+ assert rv.status_code == 422
+
+ rv_data = json.loads(rv.data)
+ assert "error" in rv_data
+ if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
+ assert "INCORRECT SQL" in rv_data.get("error")
+
+
+def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
+ uri = (
+ f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
+ )
+ rv = test_client.post(uri)
+ assert rv.status_code == 200
+ rv_data = json.loads(rv.data)
+ assert QueryCacheManager.has(
+ rv_data["result"]["cache_key"], region=CacheRegion.DATA
+ )
+ assert len(rv_data["result"]["data"]) == 10
+
+
+def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
+ uri = (
+ f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
+ )
+ rv = test_client.post(uri, json=None)
+ assert rv.status_code == 200
+
+ rv = test_client.post(uri, json={})
+ assert rv.status_code == 200
+
+ rv = test_client.post(uri, json={"foo": "bar"})
+ assert rv.status_code == 400
+
+ rv = test_client.post(
+ uri, json={"filters": [{"col": "col1", "op": "INVALID", "val": 0}]}
+ )
+ assert rv.status_code == 400
+
+ rv = test_client.post(
+ uri,
+ json={
+ "filters": [
+ {"col": "col2", "op": "==", "val": "a"},
+ {"col": "col1", "op": "==", "val": 0},
+ ]
+ },
+ )
+ assert rv.status_code == 200
+ rv_data = json.loads(rv.data)
+ assert rv_data["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
+ assert rv_data["result"]["rowcount"] == 1
+
+ # empty results
+ rv = test_client.post(
+ uri,
+ json={
+ "filters": [
+ {"col": "col2", "op": "==", "val": "x"},
+ ]
+ },
+ )
+ assert rv.status_code == 200
+ rv_data = json.loads(rv.data)
+ assert rv_data["result"]["colnames"] == []
+ assert rv_data["result"]["rowcount"] == 0
+
+
+def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
+ # 1. default page, per_page and total_count
+ uri = (
+ f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
+ )
+ rv = test_client.post(uri)
+ rv_data = json.loads(rv.data)
+ assert rv_data["result"]["page"] == 1
+ assert rv_data["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
+ assert rv_data["result"]["total_count"] == 10
+
+ # 2. incorrect per_page
+ per_pages = (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx")
+ for per_page in per_pages:
+ uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page={per_page}"
+ rv = test_client.post(uri)
+ assert rv.status_code == 400
+
+ # 3. incorrect page or datasource_type
+ uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&page=xx"
+ rv = test_client.post(uri)
+ assert rv.status_code == 400
+
+ uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=xx"
+ rv = test_client.post(uri)
+ assert rv.status_code == 400
+
+ # 4. turning pages
+ uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=1"
+ rv = test_client.post(uri)
+ rv_data = json.loads(rv.data)
+ assert rv_data["result"]["page"] == 1
+ assert rv_data["result"]["per_page"] == 2
+ assert rv_data["result"]["total_count"] == 10
+ assert [row["col1"] for row in rv_data["result"]["data"]] == [0, 1]
+
+ uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=2"
+ rv = test_client.post(uri)
+ rv_data = json.loads(rv.data)
+ assert rv_data["result"]["page"] == 2
+ assert rv_data["result"]["per_page"] == 2
+ assert rv_data["result"]["total_count"] == 10
+ assert [row["col1"] for row in rv_data["result"]["data"]] == [2, 3]
+
+ # 5. Exceeding the maximum pages
+ uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=6"
+ rv = test_client.post(uri)
+ rv_data = json.loads(rv.data)
+ assert rv_data["result"]["page"] == 6
+ assert rv_data["result"]["per_page"] == 2
+ assert rv_data["result"]["total_count"] == 10
+ assert [row["col1"] for row in rv_data["result"]["data"]] == []