You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/11/18 19:54:10 UTC

[airflow] 01/01: Webserver: Further Sanitize values passed to origin param

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch origin-fix
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit d6a1ae1a7ddb1bb67f1d35dc8a72d387575176a1
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Wed Nov 18 19:52:14 2020 +0000

    Webserver: Further Sanitize values passed to origin param
    
    Follow-up to https://github.com/apache/airflow/pull/10334
---
 airflow/www/views.py    |  7 ++++++-
 tests/www/test_views.py | 26 +++++++++++++++++++++++++-
 2 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/airflow/www/views.py b/airflow/www/views.py
index fc2d2d0..bc4ac40 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -28,7 +28,7 @@ from collections import defaultdict
 from datetime import datetime, timedelta
 from json import JSONDecodeError
 from typing import Dict, List, Optional, Tuple
-from urllib.parse import unquote, urlparse
+from urllib.parse import parse_qsl, unquote, urlencode, urlparse
 
 import lazy_object_proxy
 import nvd3
@@ -108,8 +108,13 @@ def get_safe_url(url):
     valid_schemes = ['http', 'https', '']
     valid_netlocs = [request.host, '']
 
+    # Remove single quotes
+    url = url.replace("'", "")
     parsed = urlparse(url)
 
+    query = parse_qsl(parsed.query)
+    url = parsed._replace(query=urlencode(query)).geturl()
+
     if parsed.scheme in valid_schemes and parsed.netloc in valid_netlocs:
         return url
 
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index c3e9963..39fe6d8 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -62,7 +62,7 @@ from airflow.utils.state import State
 from airflow.utils.timezone import datetime
 from airflow.utils.types import DagRunType
 from airflow.www import app as application
-from airflow.www.views import ConnectionModelView
+from airflow.www.views import ConnectionModelView, get_safe_url
 from tests.test_utils import fab_utils
 from tests.test_utils.asserts import assert_queries_count
 from tests.test_utils.config import conf_vars
@@ -2772,6 +2772,7 @@ class TestTriggerDag(TestBase):
         [
             ("javascript:alert(1)", "/home"),
             ("http://google.com", "/home"),
+            ("%2Ftree%3Fdag_id%3Dexample_bash_operator%27;alert(33)//", "/tree?dag_id=example_bash_operator"),
             ("%2Ftree%3Fdag_id%3Dexample_bash_operator", "/tree?dag_id=example_bash_operator"),
             ("%2Fgraph%3Fdag_id%3Dexample_bash_operator", "/graph?dag_id=example_bash_operator"),
         ]
@@ -3293,3 +3294,26 @@ class TestDecorators(TestBase):
         self.check_last_log(
             "example_bash_operator", event="clear", execution_date=self.EXAMPLE_DAG_DEFAULT_DATE
         )
+
+
+class TestHelperFunctions(unittest.TestCase):
+    @parameterized.expand(
+        [
+            # Disabled ("javascript:alert(1)", "/home") case; a matching pair is listed in TestTriggerDag's cases.
+            ("http://google.com", "/home"),
+            (  # NOTE(review): expected URL assumes parse_qsl also splits on ';' (default is '&' only since 3.9.2)
+                "http://localhost:8080/trigger?dag_id=test_dag&origin=%2Ftree%3Fdag_id%test_dag';alert(33)//",
+                "http://localhost:8080/trigger?dag_id=test_dag&origin=%2Ftree%3Fdag_id%25test_dag",
+            ),
+            (
+                "http://localhost:8080/trigger?dag_id=test_dag&origin=%2Ftree%3Fdag_id%test_dag",
+                "http://localhost:8080/trigger?dag_id=test_dag&origin=%2Ftree%3Fdag_id%25test_dag",
+            ),
+        ]
+    )
+    @mock.patch("airflow.www.views.url_for")
+    @mock.patch("airflow.www.views.request")
+    def test_get_safe_url(self, test_url, expected_url, mock_req, mock_url_for):
+        mock_req.host = 'localhost:8080'  # get_safe_url accepts only this netloc (or an empty one)
+        mock_url_for.return_value = "/home"  # presumably the redirect target for rejected URLs
+        self.assertEqual(get_safe_url(test_url), expected_url)