You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/15 15:05:11 UTC
[airflow] 02/02: Webserver: Sanitize values passed to origin param (#10334)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit f197534cc7cb4e9ed712f7fa9bb19e1931b017e6
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Sat Aug 15 16:01:33 2020 +0100
Webserver: Sanitize values passed to origin param (#10334)
(cherry-picked from 5c2bb7b0b0e717b11f093910b443243330ad93ca)
---
airflow/www/views.py | 37 +++++++++++++++++++++++++++----------
airflow/www_rbac/views.py | 37 +++++++++++++++++++++++++++----------
tests/www/test_views.py | 23 +++++++++++++++++++++++
tests/www_rbac/test_views.py | 16 ++++++++++++++++
4 files changed, 93 insertions(+), 20 deletions(-)
diff --git a/airflow/www/views.py b/airflow/www/views.py
index b496e72..6087356 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -54,7 +54,7 @@ from past.builtins import basestring
from pygments import highlight, lexers
import six
from pygments.formatters.html import HtmlFormatter
-from six.moves.urllib.parse import quote, unquote
+from six.moves.urllib.parse import quote, unquote, urlparse
from sqlalchemy import or_, desc, and_, union_all
from wtforms import (
@@ -328,6 +328,23 @@ def get_chart_height(dag):
return 600 + len(dag.tasks) * 10
+def get_safe_url(url):
+ """Given a user-supplied URL, ensure it points to our web server"""
+ try:
+ valid_schemes = ['http', 'https', '']
+ valid_netlocs = [request.host, '']
+
+ parsed = urlparse(url)
+ if parsed.scheme in valid_schemes and parsed.netloc in valid_netlocs:
+ return url
+ except Exception as e: # pylint: disable=broad-except
+ log.debug("Error validating value in origin parameter passed to URL: %s", url)
+ log.debug("Error: %s", e)
+ pass
+
+ return "/admin/"
+
+
def get_date_time_num_runs_dag_runs_form_data(request, session, dag):
dttm = request.args.get('execution_date')
if dttm:
@@ -1108,7 +1125,7 @@ class Airflow(AirflowViewMixin, BaseView):
def run(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
dag = dagbag.get_dag(dag_id)
task = dag.get_task(task_id)
@@ -1179,7 +1196,7 @@ class Airflow(AirflowViewMixin, BaseView):
from airflow.exceptions import DagNotFound, DagFileExists
dag_id = request.values.get('dag_id')
- origin = request.values.get('origin') or "/admin/"
+ origin = get_safe_url(request.values.get('origin'))
try:
delete_dag.delete_dag(dag_id)
@@ -1203,7 +1220,7 @@ class Airflow(AirflowViewMixin, BaseView):
@provide_session
def trigger(self, session=None):
dag_id = request.values.get('dag_id')
- origin = request.values.get('origin') or "/admin/"
+ origin = get_safe_url(request.values.get('origin'))
if request.method == 'GET':
return self.render(
@@ -1304,7 +1321,7 @@ class Airflow(AirflowViewMixin, BaseView):
def clear(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
dag = dagbag.get_dag(dag_id)
execution_date = request.form.get('execution_date')
@@ -1334,7 +1351,7 @@ class Airflow(AirflowViewMixin, BaseView):
@wwwutils.notify_owner
def dagrun_clear(self):
dag_id = request.form.get('dag_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"
@@ -1437,7 +1454,7 @@ class Airflow(AirflowViewMixin, BaseView):
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_failed(dag_id, execution_date,
confirmed, origin)
@@ -1449,7 +1466,7 @@ class Airflow(AirflowViewMixin, BaseView):
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_success(dag_id, execution_date,
confirmed, origin)
@@ -1502,7 +1519,7 @@ class Airflow(AirflowViewMixin, BaseView):
def failed(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"
@@ -1522,7 +1539,7 @@ class Airflow(AirflowViewMixin, BaseView):
def success(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"
diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py
index f098b25..9d46d03 100644
--- a/airflow/www_rbac/views.py
+++ b/airflow/www_rbac/views.py
@@ -31,7 +31,7 @@ from datetime import timedelta
from urllib.parse import unquote
import six
-from six.moves.urllib.parse import quote
+from six.moves.urllib.parse import quote, urlparse
import pendulum
import sqlalchemy as sqla
@@ -89,6 +89,23 @@ else:
dagbag = models.DagBag(os.devnull, include_examples=False)
+def get_safe_url(url):
+ """Given a user-supplied URL, ensure it points to our web server"""
+ try:
+ valid_schemes = ['http', 'https', '']
+ valid_netlocs = [request.host, '']
+
+ parsed = urlparse(url)
+ if parsed.scheme in valid_schemes and parsed.netloc in valid_netlocs:
+ return url
+ except Exception as e: # pylint: disable=broad-except
+ logging.debug("Error validating value in origin parameter passed to URL: %s", url)
+ logging.debug("Error: %s", e)
+ pass
+
+ return url_for('Airflow.index')
+
+
def get_date_time_num_runs_dag_runs_form_data(request, session, dag):
dttm = request.args.get('execution_date')
if dttm:
@@ -930,7 +947,7 @@ class Airflow(AirflowBaseView):
def run(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
dag = dagbag.get_dag(dag_id)
task = dag.get_task(task_id)
@@ -1000,7 +1017,7 @@ class Airflow(AirflowBaseView):
from airflow.exceptions import DagNotFound, DagFileExists
dag_id = request.values.get('dag_id')
- origin = request.values.get('origin') or url_for('Airflow.index')
+ origin = get_safe_url(request.values.get('origin'))
try:
delete_dag.delete_dag(dag_id)
@@ -1027,7 +1044,7 @@ class Airflow(AirflowBaseView):
def trigger(self, session=None):
dag_id = request.values.get('dag_id')
- origin = request.values.get('origin') or url_for('Airflow.index')
+ origin = get_safe_url(request.values.get('origin'))
if request.method == 'GET':
return self.render_template(
@@ -1128,7 +1145,7 @@ class Airflow(AirflowBaseView):
def clear(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
dag = dagbag.get_dag(dag_id)
execution_date = request.form.get('execution_date')
@@ -1158,7 +1175,7 @@ class Airflow(AirflowBaseView):
@action_logging
def dagrun_clear(self):
dag_id = request.form.get('dag_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"
@@ -1280,7 +1297,7 @@ class Airflow(AirflowBaseView):
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_failed(dag_id, execution_date,
confirmed, origin)
@@ -1292,7 +1309,7 @@ class Airflow(AirflowBaseView):
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_success(dag_id, execution_date,
confirmed, origin)
@@ -1345,7 +1362,7 @@ class Airflow(AirflowBaseView):
def failed(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"
@@ -1365,7 +1382,7 @@ class Airflow(AirflowBaseView):
def success(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
- origin = request.form.get('origin')
+ origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index ac71ebb..438830c 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -37,6 +37,7 @@ from flask._compat import PY2
from airflow.operators.bash_operator import BashOperator
from airflow.utils import timezone
from airflow.utils.db import create_session
+from parameterized import parameterized
from tests.compat import mock
from six.moves.urllib.parse import quote_plus
@@ -1115,6 +1116,28 @@ class TestTriggerDag(unittest.TestCase):
'Triggered example_bash_operator, it should start any moment now.',
response.data.decode('utf-8'))
+ @parameterized.expand([
+ ("javascript:alert(1)", "/admin/"),
+ ("http://google.com", "/admin/"),
+ (
+ "%2Fadmin%2Fairflow%2Ftree%3Fdag_id%3Dexample_bash_operator&dag_id=example_bash_operator",
+ "/admin/airflow/tree?dag_id=example_bash_operator"
+ ),
+ (
+ "%2Fadmin%2Fairflow%2Fgraph%3Fdag_id%3Dexample_bash_operator&dag_id=example_bash_operator",
+ "/admin/airflow/graph?dag_id=example_bash_operator"
+ ),
+ ("", ""),
+ ])
+ def test_trigger_dag_form_origin_url(self, test_origin, expected_origin):
+ test_dag_id = "example_bash_operator"
+ response = self.app.get(
+ '/admin/airflow/trigger?dag_id={}&origin={}'.format(test_dag_id, test_origin))
+ self.assertIn(
+ '<button class="btn" onclick="location.href = \'{}\'; return false">'.format(
+ expected_origin),
+ response.data.decode('utf-8'))
+
class HelpersTest(unittest.TestCase):
@classmethod
diff --git a/tests/www_rbac/test_views.py b/tests/www_rbac/test_views.py
index 33a8338..4e06b57 100644
--- a/tests/www_rbac/test_views.py
+++ b/tests/www_rbac/test_views.py
@@ -2244,6 +2244,22 @@ class TestTriggerDag(TestBase):
self.check_content_in_response(
'Triggered example_bash_operator, it should start any moment now.', response)
+ @parameterized.expand([
+ ("javascript:alert(1)", "/home"),
+ ("http://google.com", "/home"),
+ ("%2Ftree%3Fdag_id%3Dexample_bash_operator", "/tree?dag_id=example_bash_operator"),
+ ("%2Fgraph%3Fdag_id%3Dexample_bash_operator", "/graph?dag_id=example_bash_operator"),
+ ("", ""),
+ ])
+ def test_trigger_dag_form_origin_url(self, test_origin, expected_origin):
+ test_dag_id = "example_bash_operator"
+
+ resp = self.client.get('trigger?dag_id={}&origin={}'.format(test_dag_id, test_origin))
+ self.check_content_in_response(
+ '<button class="btn" onclick="location.href = \'{}\'; return false">'.format(
+ expected_origin),
+ resp)
+
@mock.patch('airflow.www_rbac.views.dagbag.get_dag')
def test_trigger_endpoint_uses_existing_dagbag(self, mock_get_dag):
"""