You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@superset.apache.org by sr...@apache.org on 2022/02/08 19:26:24 UTC

[superset] branch master updated: feat(chart-data-api): download multiple csvs as zip (#18618)

This is an automated email from the ASF dual-hosted git repository.

srini pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 125be78  feat(chart-data-api): download multiple csvs as zip (#18618)
125be78 is described below

commit 125be78ee6681b702ce5288657aba5ce190e7fce
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Tue Feb 8 21:25:06 2022 +0200

    feat(chart-data-api): download multiple csvs as zip (#18618)
    
    * feat(chart-data-api): download multiple csvs as zip
    
    * break out util
    
    * check for empty request
---
 superset/charts/data/api.py                      | 28 ++++++++++++++++------
 superset/utils/core.py                           | 12 ++++++++++
 tests/integration_tests/charts/data/api_tests.py | 30 ++++++++++++++++++++++++
 3 files changed, 63 insertions(+), 7 deletions(-)

diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py
index 6983152..d649042 100644
--- a/superset/charts/data/api.py
+++ b/superset/charts/data/api.py
@@ -21,7 +21,7 @@ import logging
 from typing import Any, Dict, Optional, TYPE_CHECKING
 
 import simplejson
-from flask import g, make_response, request
+from flask import current_app, g, make_response, request, Response
 from flask_appbuilder.api import expose, protect
 from flask_babel import gettext as _
 from marshmallow import ValidationError
@@ -44,13 +44,11 @@ from superset.connectors.base.models import BaseDatasource
 from superset.exceptions import QueryObjectValidationError
 from superset.extensions import event_logger
 from superset.utils.async_query_manager import AsyncQueryTokenException
-from superset.utils.core import json_int_dttm_ser
+from superset.utils.core import create_zip, json_int_dttm_ser
 from superset.views.base import CsvResponse, generate_download_headers
 from superset.views.base_api import statsd_metrics
 
 if TYPE_CHECKING:
-    from flask import Response
-
     from superset.common.query_context import QueryContext
 
 logger = logging.getLogger(__name__)
@@ -350,9 +348,25 @@ class ChartDataRestApi(ChartRestApi):
             if not security_manager.can_access("can_csv", "Superset"):
                 return self.response_403()
 
-            # return the first result
-            data = result["queries"][0]["data"]
-            return CsvResponse(data, headers=generate_download_headers("csv"))
+            if not result["queries"]:
+                return self.response_400(_("Empty query result"))
+
+            if len(result["queries"]) == 1:
+                # return single query results csv format
+                data = result["queries"][0]["data"]
+                return CsvResponse(data, headers=generate_download_headers("csv"))
+
+            # return multi-query csv results bundled as a zip file
+            encoding = current_app.config["CSV_EXPORT"].get("encoding", "utf-8")
+            files = {
+                f"query_{idx + 1}.csv": result["data"].encode(encoding)
+                for idx, result in enumerate(result["queries"])
+            }
+            return Response(
+                create_zip(files),
+                headers=generate_download_headers("zip"),
+                mimetype="application/zip",
+            )
 
         if result_format == ChartDataResultFormat.JSON:
             response_data = simplejson.dumps(
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 4908fd9..da69a89 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -40,6 +40,7 @@ from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
 from email.utils import formatdate
 from enum import Enum, IntEnum
+from io import BytesIO
 from timeit import default_timer
 from types import TracebackType
 from typing import (
@@ -61,6 +62,7 @@ from typing import (
     Union,
 )
 from urllib.parse import unquote_plus
+from zipfile import ZipFile
 
 import bleach
 import markdown as md
@@ -1788,3 +1790,13 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
     if limit != 0:
         return min(max_limit, limit)
     return max_limit
+
+
+def create_zip(files: Dict[str, Any]) -> BytesIO:
+    buf = BytesIO()
+    with ZipFile(buf, "w") as bundle:
+        for filename, contents in files.items():
+            with bundle.open(filename, "w") as fp:
+                fp.write(contents)
+    buf.seek(0)
+    return buf
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index d6ccd6a..6b04721 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -20,8 +20,11 @@ import json
 import unittest
 import copy
 from datetime import datetime
+from io import BytesIO
 from typing import Optional
 from unittest import mock
+from zipfile import ZipFile
+
 from flask import Response
 from tests.integration_tests.conftest import with_feature_flags
 from superset.models.sql_lab import Query
@@ -236,6 +239,16 @@ class TestPostChartDataApi(BaseTestChartDataApi):
         assert rv.status_code == 200
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_empty_request_with_csv_result_format(self):
+        """
+        Chart data API: Test empty chart data with CSV result format
+        """
+        self.query_context_payload["result_format"] = "csv"
+        self.query_context_payload["queries"] = []
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+        assert rv.status_code == 400
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_with_csv_result_format(self):
         """
         Chart data API: Test chart data with CSV result format
@@ -243,6 +256,22 @@ class TestPostChartDataApi(BaseTestChartDataApi):
         self.query_context_payload["result_format"] = "csv"
         rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
         assert rv.status_code == 200
+        assert rv.mimetype == "text/csv"
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_with_multi_query_csv_result_format(self):
+        """
+        Chart data API: Test chart data with multi-query CSV result format
+        """
+        self.query_context_payload["result_format"] = "csv"
+        self.query_context_payload["queries"].append(
+            self.query_context_payload["queries"][0]
+        )
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+        assert rv.status_code == 200
+        assert rv.mimetype == "application/zip"
+        zipfile = ZipFile(BytesIO(rv.data), "r")
+        assert zipfile.namelist() == ["query_1.csv", "query_2.csv"]
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_with_csv_result_format_when_actor_not_permitted_for_csv__403(self):
@@ -766,6 +795,7 @@ class TestGetChartDataApi(BaseTestChartDataApi):
             }
         )
         rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data")
+        assert rv.mimetype == "application/json"
         data = json.loads(rv.data.decode("utf-8"))
         assert data["result"][0]["status"] == "success"
         assert data["result"][0]["rowcount"] == 2