You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/03/19 22:03:58 UTC

[airflow] 01/03: Fix error when running tasks with Sentry integration enabled. (#13929)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 14f978e9f27c8e223dd2f8e0d121ef699d8da663
Author: Jun <Ju...@users.noreply.github.com>
AuthorDate: Sat Mar 20 05:40:22 2021 +0800

    Fix error when running tasks with Sentry integration enabled. (#13929)
    
    Co-authored-by: Ash Berlin-Taylor <as...@apache.org>
    (cherry picked from commit 0e8698d3edb3712eba0514a39d1d30fbfeeaec09)
---
 airflow/sentry.py           | 13 +++++++++---
 airflow/utils/session.py    | 21 +++++++++++-------
 tests/utils/test_session.py | 52 +++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 75 insertions(+), 11 deletions(-)

diff --git a/airflow/sentry.py b/airflow/sentry.py
index 8dc9091..62eac9a 100644
--- a/airflow/sentry.py
+++ b/airflow/sentry.py
@@ -21,7 +21,7 @@ import logging
 from functools import wraps
 
 from airflow.configuration import conf
-from airflow.utils.session import provide_session
+from airflow.utils.session import find_session_idx, provide_session
 from airflow.utils.state import State
 
 log = logging.getLogger(__name__)
@@ -149,14 +149,21 @@ if conf.getboolean("sentry", 'sentry_on', fallback=False):
 
         def enrich_errors(self, func):
             """Wrap TaskInstance._run_raw_task to support task specific tags and breadcrumbs."""
+            session_args_idx = find_session_idx(func)
 
             @wraps(func)
-            def wrapper(task_instance, *args, session=None, **kwargs):
+            def wrapper(task_instance, *args, **kwargs):
                 # Wrapping the _run_raw_task function with push_scope to contain
                 # tags and breadcrumbs to a specific Task Instance
+
+                try:
+                    session = kwargs.get('session', args[session_args_idx])
+                except IndexError:
+                    session = None
+
                 with sentry_sdk.push_scope():
                     try:
-                        return func(task_instance, *args, session=session, **kwargs)
+                        return func(task_instance, *args, **kwargs)
                     except Exception as e:
                         self.add_tagging(task_instance)
                         self.add_breadcrumbs(task_instance, session=session)
diff --git a/airflow/utils/session.py b/airflow/utils/session.py
index 4001a0f..f8b9bcd 100644
--- a/airflow/utils/session.py
+++ b/airflow/utils/session.py
@@ -40,6 +40,18 @@ def create_session():
 RT = TypeVar("RT")  # pylint: disable=invalid-name
 
 
+def find_session_idx(func: Callable[..., RT]) -> int:
+    """Return the zero-based position of the ``session`` parameter in ``func``'s signature."""
+    func_params = signature(func).parameters
+    try:
+        # Parameters are kept in declaration order, so the index in the tuple of names is the position
+        session_args_idx = tuple(func_params).index("session")
+    except ValueError:  # no parameter is named "session"
+        raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None
+
+    return session_args_idx
+
+
 def provide_session(func: Callable[..., RT]) -> Callable[..., RT]:
     """
     Function decorator that provides a session if it isn't provided.
@@ -47,14 +59,7 @@ def provide_session(func: Callable[..., RT]) -> Callable[..., RT]:
     database transaction, you pass it to the function, if not this wrapper
     will create one and close it for you.
     """
-    func_params = signature(func).parameters
-    try:
-        # func_params is an ordered dict -- this is the "recommended" way of getting the position
-        session_args_idx = tuple(func_params).index("session")
-    except ValueError:
-        raise ValueError(f"Function {func.__qualname__} has no `session` argument") from None
-    # We don't need this anymore -- ensure we don't keep a reference to it by mistake
-    del func_params
+    session_args_idx = find_session_idx(func)
 
     @wraps(func)
     def wrapper(*args, **kwargs) -> RT:
diff --git a/tests/utils/test_session.py b/tests/utils/test_session.py
new file mode 100644
index 0000000..08f317f
--- /dev/null
+++ b/tests/utils/test_session.py
@@ -0,0 +1,52 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import pytest
+
+from airflow.utils.session import provide_session
+
+
+class TestSession:  # checks how provide_session treats a missing, positional and keyword session
+    def dummy_session(self, session=None):
+        return session  # echo the session back so tests can see what the wrapper passed
+
+    def test_raised_provide_session(self):
+        with pytest.raises(ValueError, match="Function .*dummy has no `session` argument"):
+
+            @provide_session
+            def dummy():  # deliberately has no `session` parameter
+                pass
+
+    def test_provide_session_without_args_and_kwargs(self):
+        assert self.dummy_session() is None  # undecorated call: the default None comes back
+
+        wrapper = provide_session(self.dummy_session)
+
+        assert wrapper() is not None  # decorated call, no session given: expects the wrapper to supply one
+
+    def test_provide_session_with_args(self):
+        wrapper = provide_session(self.dummy_session)
+
+        session = object()  # sentinel; only its identity is checked
+        assert wrapper(session) is session  # positional session must come back unchanged
+
+    def test_provide_session_with_kwargs(self):
+        wrapper = provide_session(self.dummy_session)
+
+        session = object()  # sentinel; only its identity is checked
+        assert wrapper(session=session) is session  # keyword session must come back unchanged