You are viewing a plain-text version of this content; the canonical version is available in the original mailing-list archive (the "here" hyperlink was not preserved in this copy).
Posted to commits@superset.apache.org by hu...@apache.org on 2021/03/05 17:52:03 UTC

[superset] branch hugh-refactor-fixed-ci updated: got the test to work

This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch hugh-refactor-fixed-ci
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/hugh-refactor-fixed-ci by this push:
     new ff91073  got the test to work
ff91073 is described below

commit ff9107347cce41c8d201d6b240c5499985606de6
Author: hughhhh <hu...@gmail.com>
AuthorDate: Fri Mar 5 12:50:25 2021 -0500

    got the test to work
---
 superset/sql_lab.py   |  1 +
 superset/utils/log.py | 86 +++++++++++++++++++++++++++++++++++++++------------
 tests/celery_tests.py | 31 ++++++++++++-------
 3 files changed, 86 insertions(+), 32 deletions(-)

diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index dd82826..5ffa1ec 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -419,6 +419,7 @@ def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-loca
         # Commit the connection so CTA queries will create the table.
         conn.commit()
 
+    logger.info("hello")
     # Success, updating the query entry in database
     query.rows = result_set.size
     query.progress = 100
diff --git a/superset/utils/log.py b/superset/utils/log.py
index 824487e..243b52d 100644
--- a/superset/utils/log.py
+++ b/superset/utils/log.py
@@ -19,9 +19,9 @@ import inspect
 import json
 import logging
 import textwrap
-import time
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
+from datetime import datetime, timedelta
 from typing import Any, Callable, cast, Dict, Iterator, Optional, Type, Union
 
 from flask import current_app, g, request
@@ -58,6 +58,35 @@ def collect_request_payload() -> Dict[str, Any]:
 
 
 class AbstractEventLogger(ABC):
+    def __call__(
+        self,
+        action: str,
+        object_ref: Optional[str] = None,
+        log_to_statsd: bool = True,
+        duration: Optional[timedelta] = None,
+        **payload_override: Dict[str, Any],
+    ) -> object:
+        # pylint: disable=W0201
+        self.action = action
+        self.object_ref = object_ref
+        self.log_to_statsd = log_to_statsd
+        self.payload_override = payload_override
+        return self
+
+    def __enter__(self) -> None:
+        # pylint: disable=W0201
+        self.start = datetime.now()
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        # Log data w/ arguments being passed in
+        self.log_with_context(
+            action=self.action,
+            object_ref=self.object_ref,
+            log_to_statsd=self.log_to_statsd,
+            duration=datetime.now() - self.start,
+            **self.payload_override,
+        )
+
     @abstractmethod
     def log(  # pylint: disable=too-many-arguments
         self,
@@ -72,32 +101,24 @@ class AbstractEventLogger(ABC):
     ) -> None:
         pass
 
-    @contextmanager
-    def log_context(  # pylint: disable=too-many-locals
-        self, action: str, object_ref: Optional[str] = None, log_to_statsd: bool = True,
-    ) -> Iterator[Callable[..., None]]:
-        """
-        Log an event with additional information from the request context.
-
-        :param action: a name to identify the event
-        :param object_ref: reference to the Python object that triggered this action
-        :param log_to_statsd: whether to update statsd counter for the action
-        """
+    def log_with_context(  # pylint: disable=too-many-locals
+        self,
+        action: str,
+        duration: timedelta,
+        object_ref: Optional[str] = None,
+        log_to_statsd: bool = True,
+        **payload_override: Optional[Dict[str, Any]],
+    ) -> None:
         from superset.views.core import get_form_data
 
-        start_time = time.time()
         referrer = request.referrer[:1000] if request.referrer else None
         user_id = g.user.get_id() if hasattr(g, "user") and g.user else None
-        payload_override = {}
-
-        # yield a helper to add additional payload
-        yield lambda **kwargs: payload_override.update(kwargs)
 
         payload = collect_request_payload()
         if object_ref:
             payload["object_ref"] = object_ref
-        # manual updates from context comes the last
-        payload.update(payload_override)
+        if payload_override:
+            payload.update(payload_override)
 
         dashboard_id: Optional[int] = None
         try:
@@ -133,10 +154,35 @@ class AbstractEventLogger(ABC):
             records=records,
             dashboard_id=dashboard_id,
             slice_id=slice_id,
-            duration_ms=round((time.time() - start_time) * 1000),
+            duration_ms=int(duration.total_seconds() * 1000),
             referrer=referrer,
         )
 
+    @contextmanager
+    def log_context(  # pylint: disable=too-many-locals
+        self, action: str, object_ref: Optional[str] = None, log_to_statsd: bool = True,
+    ) -> Iterator[Callable[..., None]]:
+        """
+        Log an event with additional information from the request context.
+        :param action: a name to identify the event
+        :param object_ref: reference to the Python object that triggered this action
+        :param log_to_statsd: whether to update statsd counter for the action
+        """
+        logging.info("in event_logger1")
+        payload_override = {}
+        start = datetime.now()
+        # yield a helper to add additional payload
+        yield lambda **kwargs: payload_override.update(kwargs)
+        duration = datetime.now() - start
+
+        logging.info("in event_logger2")
+
+        self.log_with_context(
+            action, duration, object_ref, log_to_statsd, **payload_override
+        )
+
+        logging.info("in event_logger3")
+
     def _wrapper(
         self,
         f: Callable[..., Any],
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index 13c7ac3..79458bb 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -28,12 +28,12 @@ from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with
 import pytest
 
 import flask
-from flask import current_app
+from flask import current_app, g
 
 from tests.base_tests import login
 from tests.conftest import CTAS_SCHEMA_NAME
 from tests.test_app import app
-from superset import db, sql_lab
+from superset import db, sql_lab, security_manager
 from superset.result_set import SupersetResultSet
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.errors import ErrorLevel, SupersetErrorType
@@ -163,19 +163,26 @@ def test_run_sync_query_dont_exist(setup_sqllab, ctas_method):
 
 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
 @pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
-def test_run_sync_query_cta(setup_sqllab, ctas_method):
+@mock.patch("superset.utils.log.g", spec={})
+def test_run_sync_query_cta(
+    mock_g, setup_sqllab, ctas_method,
+):
     tmp_table_name = f"{TEST_SYNC}_{ctas_method.lower()}"
-    result = run_sql(QUERY, tmp_table=tmp_table_name, cta=True, ctas_method=ctas_method)
-    assert QueryStatus.SUCCESS == result["query"]["state"], result
-    assert cta_result(ctas_method) == (result["data"], result["columns"])
+    with app.test_request_context():
+        mock_g.return_value = 123
+        result = run_sql(
+            QUERY, tmp_table=tmp_table_name, cta=True, ctas_method=ctas_method
+        )
+        assert QueryStatus.SUCCESS == result["query"]["state"], result
+        assert cta_result(ctas_method) == (result["data"], result["columns"])
 
-    # Check the data in the tmp table.
-    select_query = get_query_by_id(result["query"]["serverId"])
-    results = run_sql(select_query.select_sql)
-    assert QueryStatus.SUCCESS == results["status"], results
-    assert len(results["data"]) > 0
+        # Check the data in the tmp table.
+        select_query = get_query_by_id(result["query"]["serverId"])
+        results = run_sql(select_query.select_sql)
+        assert QueryStatus.SUCCESS == results["status"], results
+        assert len(results["data"]) > 0
 
-    delete_tmp_view_or_table(tmp_table_name, ctas_method)
+        delete_tmp_view_or_table(tmp_table_name, ctas_method)
 
 
 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")