You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/15 15:04:03 UTC

[airflow] branch v1-10-test updated: Webserver: Sanitize values passed to origin param (#10334)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v1-10-test by this push:
     new 4f8343c  Webserver: Sanitize values passed to origin param (#10334)
4f8343c is described below

commit 4f8343cbc73bcd1fabb77fdad859001dad0971cd
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Sat Aug 15 16:01:33 2020 +0100

    Webserver: Sanitize values passed to origin param (#10334)
    
    (cherry-picked from 5c2bb7b0b0e717b11f093910b443243330ad93ca)
---
 airflow/www/views.py         | 37 +++++++++++++++++++++++++++----------
 airflow/www_rbac/views.py    | 37 +++++++++++++++++++++++++++----------
 tests/www/test_views.py      | 23 +++++++++++++++++++++++
 tests/www_rbac/test_views.py | 16 ++++++++++++++++
 4 files changed, 93 insertions(+), 20 deletions(-)

diff --git a/airflow/www/views.py b/airflow/www/views.py
index b496e72..6087356 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -54,7 +54,7 @@ from past.builtins import basestring
 from pygments import highlight, lexers
 import six
 from pygments.formatters.html import HtmlFormatter
-from six.moves.urllib.parse import quote, unquote
+from six.moves.urllib.parse import quote, unquote, urlparse
 
 from sqlalchemy import or_, desc, and_, union_all
 from wtforms import (
@@ -328,6 +328,23 @@ def get_chart_height(dag):
     return 600 + len(dag.tasks) * 10
 
 
+def get_safe_url(url):
+    """Return ``url`` if it points to our web server (relative or same host), else "/admin/"."""
+    try:
+        # Normalize like a browser (drop tab/CR/LF, strip C0 controls/space, read '\' as '/'),
+        # then insist any "//" authority names our host ("///evil.com", "http:evil.com" fail).
+        url = ''.join(c for c in url if c not in '\t\r\n').strip(''.join(map(chr, range(33))))
+        url = url.replace('\\', '/')
+        parsed = urlparse(url)
+        authority = url[len(parsed.scheme) + bool(parsed.scheme):].startswith('//')
+        if (parsed.scheme in ('http', 'https', '') and parsed.netloc in (request.host, '') and
+                authority == bool(parsed.netloc) and (parsed.netloc or not parsed.scheme)):
+            return url
+    except (TypeError, ValueError) as e:  # missing (None) origin, or e.g. a malformed IPv6 host
+        log.debug("Error validating value in origin parameter passed to URL: %s (%s)", url, e)
+    return "/admin/"
+
+
 def get_date_time_num_runs_dag_runs_form_data(request, session, dag):
     dttm = request.args.get('execution_date')
     if dttm:
@@ -1108,7 +1125,7 @@ class Airflow(AirflowViewMixin, BaseView):
     def run(self):
         dag_id = request.form.get('dag_id')
         task_id = request.form.get('task_id')
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
 
         dag = dagbag.get_dag(dag_id)
         task = dag.get_task(task_id)
@@ -1179,7 +1196,7 @@ class Airflow(AirflowViewMixin, BaseView):
         from airflow.exceptions import DagNotFound, DagFileExists
 
         dag_id = request.values.get('dag_id')
-        origin = request.values.get('origin') or "/admin/"
+        origin = get_safe_url(request.values.get('origin'))
 
         try:
             delete_dag.delete_dag(dag_id)
@@ -1203,7 +1220,7 @@ class Airflow(AirflowViewMixin, BaseView):
     @provide_session
     def trigger(self, session=None):
         dag_id = request.values.get('dag_id')
-        origin = request.values.get('origin') or "/admin/"
+        origin = get_safe_url(request.values.get('origin'))
 
         if request.method == 'GET':
             return self.render(
@@ -1304,7 +1321,7 @@ class Airflow(AirflowViewMixin, BaseView):
     def clear(self):
         dag_id = request.form.get('dag_id')
         task_id = request.form.get('task_id')
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
         dag = dagbag.get_dag(dag_id)
 
         execution_date = request.form.get('execution_date')
@@ -1334,7 +1351,7 @@ class Airflow(AirflowViewMixin, BaseView):
     @wwwutils.notify_owner
     def dagrun_clear(self):
         dag_id = request.form.get('dag_id')
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
         execution_date = request.form.get('execution_date')
         confirmed = request.form.get('confirmed') == "true"
 
@@ -1437,7 +1454,7 @@ class Airflow(AirflowViewMixin, BaseView):
         dag_id = request.form.get('dag_id')
         execution_date = request.form.get('execution_date')
         confirmed = request.form.get('confirmed') == 'true'
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
         return self._mark_dagrun_state_as_failed(dag_id, execution_date,
                                                  confirmed, origin)
 
@@ -1449,7 +1466,7 @@ class Airflow(AirflowViewMixin, BaseView):
         dag_id = request.form.get('dag_id')
         execution_date = request.form.get('execution_date')
         confirmed = request.form.get('confirmed') == 'true'
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
         return self._mark_dagrun_state_as_success(dag_id, execution_date,
                                                   confirmed, origin)
 
@@ -1502,7 +1519,7 @@ class Airflow(AirflowViewMixin, BaseView):
     def failed(self):
         dag_id = request.form.get('dag_id')
         task_id = request.form.get('task_id')
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
         execution_date = request.form.get('execution_date')
 
         confirmed = request.form.get('confirmed') == "true"
@@ -1522,7 +1539,7 @@ class Airflow(AirflowViewMixin, BaseView):
     def success(self):
         dag_id = request.form.get('dag_id')
         task_id = request.form.get('task_id')
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
         execution_date = request.form.get('execution_date')
 
         confirmed = request.form.get('confirmed') == "true"
diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py
index f098b25..9d46d03 100644
--- a/airflow/www_rbac/views.py
+++ b/airflow/www_rbac/views.py
@@ -31,7 +31,7 @@ from datetime import timedelta
 from urllib.parse import unquote
 
 import six
-from six.moves.urllib.parse import quote
+from six.moves.urllib.parse import quote, urlparse
 
 import pendulum
 import sqlalchemy as sqla
@@ -89,6 +89,23 @@ else:
     dagbag = models.DagBag(os.devnull, include_examples=False)
 
 
+def get_safe_url(url):
+    """Return ``url`` if it points to our web server (relative or same host), else the index."""
+    try:
+        # Normalize like a browser (drop tab/CR/LF, strip C0 controls/space, read '\' as '/'),
+        # then insist any "//" authority names our host ("///evil.com", "http:evil.com" fail).
+        url = ''.join(c for c in url if c not in '\t\r\n').strip(''.join(map(chr, range(33))))
+        url = url.replace('\\', '/')
+        parsed = urlparse(url)
+        authority = url[len(parsed.scheme) + bool(parsed.scheme):].startswith('//')
+        if (parsed.scheme in ('http', 'https', '') and parsed.netloc in (request.host, '') and
+                authority == bool(parsed.netloc) and (parsed.netloc or not parsed.scheme)):
+            return url
+    except (TypeError, ValueError) as e:  # missing (None) origin, or e.g. a malformed IPv6 host
+        logging.debug("Error validating value in origin parameter passed to URL: %s (%s)", url, e)
+    return url_for('Airflow.index')
+
+
 def get_date_time_num_runs_dag_runs_form_data(request, session, dag):
     dttm = request.args.get('execution_date')
     if dttm:
@@ -930,7 +947,7 @@ class Airflow(AirflowBaseView):
     def run(self):
         dag_id = request.form.get('dag_id')
         task_id = request.form.get('task_id')
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
         dag = dagbag.get_dag(dag_id)
         task = dag.get_task(task_id)
 
@@ -1000,7 +1017,7 @@ class Airflow(AirflowBaseView):
         from airflow.exceptions import DagNotFound, DagFileExists
 
         dag_id = request.values.get('dag_id')
-        origin = request.values.get('origin') or url_for('Airflow.index')
+        origin = get_safe_url(request.values.get('origin'))
 
         try:
             delete_dag.delete_dag(dag_id)
@@ -1027,7 +1044,7 @@ class Airflow(AirflowBaseView):
     def trigger(self, session=None):
 
         dag_id = request.values.get('dag_id')
-        origin = request.values.get('origin') or url_for('Airflow.index')
+        origin = get_safe_url(request.values.get('origin'))
 
         if request.method == 'GET':
             return self.render_template(
@@ -1128,7 +1145,7 @@ class Airflow(AirflowBaseView):
     def clear(self):
         dag_id = request.form.get('dag_id')
         task_id = request.form.get('task_id')
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
         dag = dagbag.get_dag(dag_id)
 
         execution_date = request.form.get('execution_date')
@@ -1158,7 +1175,7 @@ class Airflow(AirflowBaseView):
     @action_logging
     def dagrun_clear(self):
         dag_id = request.form.get('dag_id')
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
         execution_date = request.form.get('execution_date')
         confirmed = request.form.get('confirmed') == "true"
 
@@ -1280,7 +1297,7 @@ class Airflow(AirflowBaseView):
         dag_id = request.form.get('dag_id')
         execution_date = request.form.get('execution_date')
         confirmed = request.form.get('confirmed') == 'true'
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
         return self._mark_dagrun_state_as_failed(dag_id, execution_date,
                                                  confirmed, origin)
 
@@ -1292,7 +1309,7 @@ class Airflow(AirflowBaseView):
         dag_id = request.form.get('dag_id')
         execution_date = request.form.get('execution_date')
         confirmed = request.form.get('confirmed') == 'true'
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
         return self._mark_dagrun_state_as_success(dag_id, execution_date,
                                                   confirmed, origin)
 
@@ -1345,7 +1362,7 @@ class Airflow(AirflowBaseView):
     def failed(self):
         dag_id = request.form.get('dag_id')
         task_id = request.form.get('task_id')
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
         execution_date = request.form.get('execution_date')
 
         confirmed = request.form.get('confirmed') == "true"
@@ -1365,7 +1382,7 @@ class Airflow(AirflowBaseView):
     def success(self):
         dag_id = request.form.get('dag_id')
         task_id = request.form.get('task_id')
-        origin = request.form.get('origin')
+        origin = get_safe_url(request.form.get('origin'))
         execution_date = request.form.get('execution_date')
 
         confirmed = request.form.get('confirmed') == "true"
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index ac71ebb..438830c 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -37,6 +37,7 @@ from flask._compat import PY2
 from airflow.operators.bash_operator import BashOperator
 from airflow.utils import timezone
 from airflow.utils.db import create_session
+from parameterized import parameterized
 from tests.compat import mock
 
 from six.moves.urllib.parse import quote_plus
@@ -1115,6 +1116,28 @@ class TestTriggerDag(unittest.TestCase):
             'Triggered example_bash_operator, it should start any moment now.',
             response.data.decode('utf-8'))
 
+    @parameterized.expand([
+        ("javascript:alert(1)", "/admin/"),  # non-http(s) scheme: fall back to /admin/
+        ("http://google.com", "/admin/"),  # different host: fall back to /admin/
+        (  # the trailing "&dag_id=..." is a separate query arg, not part of origin
+            "%2Fadmin%2Fairflow%2Ftree%3Fdag_id%3Dexample_bash_operator&dag_id=example_bash_operator",
+            "/admin/airflow/tree?dag_id=example_bash_operator"
+        ),
+        (
+            "%2Fadmin%2Fairflow%2Fgraph%3Fdag_id%3Dexample_bash_operator&dag_id=example_bash_operator",
+            "/admin/airflow/graph?dag_id=example_bash_operator"
+        ),
+        ("", ""),  # an empty origin is rendered as-is
+    ])
+    def test_trigger_dag_form_origin_url(self, test_origin, expected_origin):
+        test_dag_id = "example_bash_operator"
+        response = self.app.get(  # test_origin goes into the URL verbatim, hence pre-encoded values
+            '/admin/airflow/trigger?dag_id={}&origin={}'.format(test_dag_id, test_origin))
+        self.assertIn(  # the form's button must navigate to the sanitized origin
+            '<button class="btn" onclick="location.href = \'{}\'; return false">'.format(
+                expected_origin),
+            response.data.decode('utf-8'))
+
 
 class HelpersTest(unittest.TestCase):
     @classmethod
diff --git a/tests/www_rbac/test_views.py b/tests/www_rbac/test_views.py
index 33a8338..4e06b57 100644
--- a/tests/www_rbac/test_views.py
+++ b/tests/www_rbac/test_views.py
@@ -2244,6 +2244,22 @@ class TestTriggerDag(TestBase):
         self.check_content_in_response(
             'Triggered example_bash_operator, it should start any moment now.', response)
 
+    @parameterized.expand([
+        ("javascript:alert(1)", "/home"),  # non-http(s) scheme: fall back to the index page
+        ("http://google.com", "/home"),  # different host: fall back to the index page
+        ("%2Ftree%3Fdag_id%3Dexample_bash_operator", "/tree?dag_id=example_bash_operator"),
+        ("%2Fgraph%3Fdag_id%3Dexample_bash_operator", "/graph?dag_id=example_bash_operator"),
+        ("", ""),  # an empty origin is rendered as-is
+    ])
+    def test_trigger_dag_form_origin_url(self, test_origin, expected_origin):
+        test_dag_id = "example_bash_operator"
+
+        resp = self.client.get('trigger?dag_id={}&origin={}'.format(test_dag_id, test_origin))
+        self.check_content_in_response(  # the form's button must navigate to the sanitized origin
+            '<button class="btn" onclick="location.href = \'{}\'; return false">'.format(
+                expected_origin),
+            resp)
+
     @mock.patch('airflow.www_rbac.views.dagbag.get_dag')
     def test_trigger_endpoint_uses_existing_dagbag(self, mock_get_dag):
         """