You are viewing a plain-text version of this content; the link to its canonical version did not survive the plain-text conversion.
Posted to commits@superset.apache.org by yj...@apache.org on 2020/10/14 17:44:23 UTC

[incubator-superset] branch master updated: refactor: use contextmanager for event_logger decorators (#11222)

This is an automated email from the ASF dual-hosted git repository.

yjc pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 634676d  refactor: use contextmanager for event_logger decorators (#11222)
634676d is described below

commit 634676d467a57701d80df09a97ac35906001b34f
Author: Jesse Yang <je...@airbnb.com>
AuthorDate: Wed Oct 14 10:44:06 2020 -0700

    refactor: use contextmanager for event_logger decorators (#11222)
---
 superset/__init__.py        |   1 -
 superset/utils/log.py       | 109 ++++++++++++++++++++++++++------------------
 superset/views/core.py      |  16 +++----
 tests/event_logger_tests.py |  41 ++++++++++++++++-
 4 files changed, 113 insertions(+), 54 deletions(-)

diff --git a/superset/__init__.py b/superset/__init__.py
index 2d76afd..4dbee10 100644
--- a/superset/__init__.py
+++ b/superset/__init__.py
@@ -33,7 +33,6 @@ from superset.extensions import (
     talisman,
 )
 from superset.security import SupersetSecurityManager
-from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value

 #  All of the fields located here should be considered legacy. The correct way
 #  to declare "global" dependencies is to define it in extensions.py,
diff --git a/superset/utils/log.py b/superset/utils/log.py
index cea161e..4380006 100644
--- a/superset/utils/log.py
+++ b/superset/utils/log.py
@@ -19,9 +19,10 @@ import inspect
 import json
 import logging
 import textwrap
+import time
 from abc import ABC, abstractmethod
-from datetime import datetime
-from typing import Any, Callable, cast, Optional, Type
+from contextlib import contextmanager
+from typing import Any, Callable, cast, Iterator, Optional, Type

 from flask import current_app, g, request
 from sqlalchemy.exc import SQLAlchemyError
@@ -36,58 +37,76 @@ class AbstractEventLogger(ABC):
     ) -> None:
         pass

-    def log_this(self, f: Callable[..., Any]) -> Callable[..., Any]:
+    @contextmanager
+    def log_context(self, action: str) -> Iterator[Callable[..., None]]:
+        """Log an event for `action` using data from the current request context.
+        Kwargs passed to the yielded callable are merged into the log payload;
+        nothing is logged if the body of the `with` block raises.
+        """
         from superset.views.core import get_form_data

-        @functools.wraps(f)
-        def wrapper(*args: Any, **kwargs: Any) -> Any:
-            user_id = None
-            if hasattr(g, "user") and g.user:
-                user_id = g.user.get_id()
-            payload = request.form.to_dict() or {}
+        start_time = time.time()
+        referrer = request.referrer[:1000] if request.referrer else None
+        user_id = g.user.get_id() if hasattr(g, "user") and g.user else None
+        payload = request.form.to_dict() or {}
+        # request parameters can overwrite post body
+        payload.update(request.args.to_dict())

-            # request parameters can overwrite post body
-            request_params = request.args.to_dict()
-            payload.update(request_params)
-            payload.update(kwargs)
+        # yield a callback that merges extra kwargs into the log payload
+        yield lambda **kwargs: payload.update(kwargs)

-            dashboard_id = payload.get("dashboard_id")
+        dashboard_id = payload.get("dashboard_id")

-            if "form_data" in payload:
-                form_data, _ = get_form_data()
-                payload["form_data"] = form_data
-                slice_id = form_data.get("slice_id")
-            else:
-                slice_id = payload.get("slice_id")
+        if "form_data" in payload:
+            form_data, _ = get_form_data()
+            payload["form_data"] = form_data
+            slice_id = form_data.get("slice_id")
+        else:
+            slice_id = payload.get("slice_id")

-            try:
-                slice_id = int(slice_id)  # type: ignore
-            except (TypeError, ValueError):
-                slice_id = 0
+        try:
+            slice_id = int(slice_id)  # type: ignore
+        except (TypeError, ValueError):
+            slice_id = 0

-            self.stats_logger.incr(f.__name__)
-            start_dttm = datetime.now()
-            value = f(*args, **kwargs)
-            duration_ms = (datetime.now() - start_dttm).total_seconds() * 1000
+        self.stats_logger.incr(action)

-            # bulk insert
-            try:
-                explode_by = payload.get("explode")
-                records = json.loads(payload.get(explode_by))  # type: ignore
-            except Exception:  # pylint: disable=broad-except
-                records = [payload]
+        # bulk insert
+        try:
+            explode_by = payload.get("explode")
+            records = json.loads(payload.get(explode_by))  # type: ignore
+        except Exception:  # pylint: disable=broad-except
+            records = [payload]
+
+        self.log(
+            user_id,
+            action,
+            records=records,
+            dashboard_id=dashboard_id,
+            slice_id=slice_id,
+            duration_ms=round((time.time() - start_time) * 1000),
+            referrer=referrer,
+        )
+
+    def log_this(self, f: Callable[..., Any]) -> Callable[..., Any]:
+        @functools.wraps(f)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            with self.log_context(f.__name__) as log:
+                value = f(*args, **kwargs)
+                log(**kwargs)
+            return value

-            referrer = request.referrer[:1000] if request.referrer else None
+        return wrapper

-            self.log(
-                user_id,
-                f.__name__,
-                records=records,
-                dashboard_id=dashboard_id,
-                slice_id=slice_id,
-                duration_ms=duration_ms,
-                referrer=referrer,
-            )
+    def log_manually(self, f: Callable[..., Any]) -> Callable[..., Any]:
+        """Pass `f` an `update_log_payload` callback to extend the log payload."""
+
+        @functools.wraps(f)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            with self.log_context(f.__name__) as log:
+                # `update_log_payload` is passed by keyword, so `f` must declare a
+                # parameter of that name that callers never fill positionally
+                value = f(*args, update_log_payload=log, **kwargs)
             return value

         return wrapper
@@ -141,6 +160,8 @@ def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:


 class DBEventLogger(AbstractEventLogger):
+    """Event logger that commits logs to Superset DB"""
+
     def log(  # pylint: disable=too-many-locals
         self, user_id: Optional[int], action: str, *args: Any, **kwargs: Any
     ) -> None:
diff --git a/superset/views/core.py b/superset/views/core.py
index 561aed8..f43a603 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -19,7 +19,7 @@ import logging
 import re
 from contextlib import closing
 from datetime import datetime
-from typing import Any, cast, Dict, List, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Union
 from urllib import parse

 import backoff
@@ -1602,8 +1602,13 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods

     @has_access
     @expose("/dashboard/<dashboard_id_or_slug>/")
+    @event_logger.log_manually
     def dashboard(  # pylint: disable=too-many-locals
-        self, dashboard_id_or_slug: str
+        self,
+        dashboard_id_or_slug: str,
+        # injected by `event_logger.log_manually` as a keyword argument; the
+        # default is a no-op that exists only to appease pylint
+        update_log_payload: Callable[..., None] = lambda **kwargs: None,
     ) -> FlaskResponse:
         """Server side rendering for a dashboard"""
         session = db.session()
@@ -1652,12 +1657,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             request.args.get(utils.ReservedUrlParameters.EDIT_MODE.value) == "true"
         )

-        # Hack to log the dashboard_id properly, even when getting a slug
-        @event_logger.log_this
-        def dashboard(**_: Any) -> None:
-            pass
-
-        dashboard(
+        update_log_payload(
             dashboard_id=dash.id,
             dashboard_version="v2",
             dash_edit_perm=dash_edit_perm,
diff --git a/tests/event_logger_tests.py b/tests/event_logger_tests.py
index a7a1f22..23f3741 100644
--- a/tests/event_logger_tests.py
+++ b/tests/event_logger_tests.py
@@ -15,9 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
+import time
 import unittest
+from datetime import datetime
+from unittest.mock import patch

-from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
+from superset.utils.log import (
+    AbstractEventLogger,
+    DBEventLogger,
+    get_event_logger_from_cfg_value,
+)
+from tests.test_app import app


 class TestEventLogger(unittest.TestCase):
@@ -42,3 +50,34 @@ class TestEventLogger(unittest.TestCase):
         # test that assignment of non AbstractEventLogger derived type raises TypeError
         with self.assertRaises(TypeError):
             get_event_logger_from_cfg_value(logging.getLogger())
+
+    @patch.object(DBEventLogger, "log")
+    def test_log_this_decorator(self, mock_log):
+        logger = DBEventLogger()
+
+        @logger.log_this
+        def test_func():
+            time.sleep(0.05)
+            return 1
+
+        with app.test_request_context():
+            result = test_func()
+            self.assertEqual(result, 1)
+            assert mock_log.call_args[1]["duration_ms"] >= 50
+
+    @patch.object(DBEventLogger, "log")
+    def test_log_manually_decorator(self, mock_log):
+        logger = DBEventLogger()
+
+        @logger.log_manually
+        def test_func(arg1, update_log_payload, karg1=1):
+            time.sleep(0.1)
+            update_log_payload(foo="bar")
+            return arg1 * karg1
+
+        with app.test_request_context():
+            result = test_func(1, karg1=2)  # pylint: disable=no-value-for-parameter
+            self.assertEqual(result, 2)
+            # only the manual payload is logged; call kwargs (karg1) are not
+            self.assertEqual(mock_log.call_args[1]["records"], [{"foo": "bar"}])
+            assert mock_log.call_args[1]["duration_ms"] >= 100