You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by je...@apache.org on 2022/09/23 03:19:33 UTC
[airflow] 12/13: Correctly set json_provider_class on Flask app so it uses our encoder (#26554)
This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch v2-4-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 5ae2ae531bba1d15975ba501c13f130959e8dd38
Author: Ash Berlin-Taylor <as...@apache.org>
AuthorDate: Wed Sep 21 22:12:39 2022 +0100
Correctly set json_provider_class on Flask app so it uses our encoder (#26554)
Setting `json_provider_class` where we did had no effect, as it turns
out `Flask()` sets `self.json = self.json_provider_class(self)`, so we
were setting it too late.
(cherry picked from commit 378dfbe2fe266f17859dbabd34b9bc8cd5c904ab)
---
airflow/utils/json.py | 20 ++++++++++++++++++--
airflow/www/app.py | 5 +++--
airflow/www/utils.py | 10 +---------
airflow/www/views.py | 33 +++++++++++++++++----------------
tests/www/test_app.py | 10 ++++++++++
5 files changed, 49 insertions(+), 29 deletions(-)
diff --git a/airflow/utils/json.py b/airflow/utils/json.py
index fcc4eedd6e..ff11097824 100644
--- a/airflow/utils/json.py
+++ b/airflow/utils/json.py
@@ -17,11 +17,12 @@
# under the License.
from __future__ import annotations
+import json
import logging
from datetime import date, datetime
from decimal import Decimal
-from flask.json import JSONEncoder
+from flask.json.provider import JSONProvider
from airflow.utils.timezone import convert_to_utc, is_naive
@@ -40,7 +41,7 @@ except ImportError:
log = logging.getLogger(__name__)
-class AirflowJsonEncoder(JSONEncoder):
+class AirflowJsonEncoder(json.JSONEncoder):
"""Custom Airflow json encoder implementation."""
def __init__(self, *args, **kwargs):
@@ -107,3 +108,18 @@ class AirflowJsonEncoder(JSONEncoder):
return {}
raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
+
+
+class AirflowJsonProvider(JSONProvider):
+ """JSON Provider for Flask app to use AirflowJsonEncoder."""
+
+ ensure_ascii: bool = True
+ sort_keys: bool = True
+
+ def dumps(self, obj, **kwargs):
+ kwargs.setdefault('ensure_ascii', self.ensure_ascii)
+ kwargs.setdefault('sort_keys', self.sort_keys)
+ return json.dumps(obj, **kwargs, cls=AirflowJsonEncoder)
+
+ def loads(self, s: str | bytes, **kwargs):
+ return json.loads(s, **kwargs)
diff --git a/airflow/www/app.py b/airflow/www/app.py
index d40f3badb8..b67314c99a 100644
--- a/airflow/www/app.py
+++ b/airflow/www/app.py
@@ -32,7 +32,7 @@ from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, RemovedInAirflow3Warning
from airflow.logging_config import configure_logging
from airflow.models import import_all_models
-from airflow.utils.json import AirflowJsonEncoder
+from airflow.utils.json import AirflowJsonProvider
from airflow.www.extensions.init_appbuilder import init_appbuilder
from airflow.www.extensions.init_appbuilder_links import init_appbuilder_links
from airflow.www.extensions.init_dagbag import init_dagbag
@@ -109,7 +109,8 @@ def create_app(config=None, testing=False):
flask_app.config['SQLALCHEMY_ENGINE_OPTIONS'] = settings.prepare_engine_args()
# Configure the JSON encoder used by `|tojson` filter from Flask
- flask_app.json_provider_class = AirflowJsonEncoder
+ flask_app.json_provider_class = AirflowJsonProvider
+ flask_app.json = AirflowJsonProvider(flask_app)
csrf.init_app(flask_app)
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 596fc4218e..3429d6a140 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -24,7 +24,7 @@ from typing import Any
from urllib.parse import urlencode
import sqlalchemy as sqla
-from flask import Response, request, url_for
+from flask import request, url_for
from flask.helpers import flash
from flask_appbuilder.forms import FieldConverter
from flask_appbuilder.models.filters import BaseFilter
@@ -47,7 +47,6 @@ from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.code_utils import get_python_source
from airflow.utils.helpers import alchemy_to_dict
-from airflow.utils.json import AirflowJsonEncoder
from airflow.utils.state import State, TaskInstanceState
from airflow.www.forms import DateTimeWithTimezoneField
from airflow.www.widgets import AirflowDateTimePickerWidget
@@ -322,13 +321,6 @@ def epoch(dttm):
return (int(time.mktime(dttm.timetuple())) * 1000,)
-def json_response(obj):
- """Returns a json response from a json serializable python object"""
- return Response(
- response=json.dumps(obj, indent=4, cls=AirflowJsonEncoder), status=200, mimetype="application/json"
- )
-
-
def make_cache_key(*args, **kwargs):
"""Used by cache to get a unique key per URL"""
path = request.path
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 5d609d044b..65e825510d 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -37,6 +37,7 @@ from typing import Any, Callable
from urllib.parse import parse_qsl, unquote, urlencode, urlparse
import configupdater
+import flask.json
import lazy_object_proxy
import markupsafe
import nvd3
@@ -107,7 +108,7 @@ from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import RUNNING_DEPS, SCHEDULER_QUEUED_DEPS
from airflow.timetables.base import DataInterval, TimeRestriction
from airflow.timetables.interval import CronDataIntervalTimetable
-from airflow.utils import json as utils_json, timezone, yaml
+from airflow.utils import timezone, yaml
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.dag_edges import dag_edges
from airflow.utils.dates import infer_time_unit, scale_time_units
@@ -575,7 +576,7 @@ class Airflow(AirflowBaseView):
'latest_scheduler_heartbeat': latest_scheduler_heartbeat,
}
- return wwwutils.json_response(payload)
+ return flask.json.jsonify(payload)
@expose('/home')
@auth.has_access(
@@ -856,7 +857,7 @@ class Airflow(AirflowBaseView):
filter_dag_ids = allowed_dag_ids
if not filter_dag_ids:
- return wwwutils.json_response({})
+ return flask.json.jsonify({})
payload = {}
dag_state_stats = dag_state_stats.filter(dr.dag_id.in_(filter_dag_ids))
@@ -873,7 +874,7 @@ class Airflow(AirflowBaseView):
count = data.get(dag_id, {}).get(state, 0)
payload[dag_id].append({'state': state, 'count': count})
- return wwwutils.json_response(payload)
+ return flask.json.jsonify(payload)
@expose('/task_stats', methods=['POST'])
@auth.has_access(
@@ -889,7 +890,7 @@ class Airflow(AirflowBaseView):
allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
if not allowed_dag_ids:
- return wwwutils.json_response({})
+ return flask.json.jsonify({})
# Filter by post parameters
selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id}
@@ -983,7 +984,7 @@ class Airflow(AirflowBaseView):
for state in State.task_states:
count = data.get(dag_id, {}).get(state, 0)
payload[dag_id].append({'state': state, 'count': count})
- return wwwutils.json_response(payload)
+ return flask.json.jsonify(payload)
@expose('/last_dagruns', methods=['POST'])
@auth.has_access(
@@ -1006,7 +1007,7 @@ class Airflow(AirflowBaseView):
filter_dag_ids = allowed_dag_ids
if not filter_dag_ids:
- return wwwutils.json_response({})
+ return flask.json.jsonify({})
last_runs_subquery = (
session.query(
@@ -1046,7 +1047,7 @@ class Airflow(AirflowBaseView):
}
for r in query
}
- return wwwutils.json_response(resp)
+ return flask.json.jsonify(resp)
@expose('/code')
@auth.has_access(
@@ -2104,7 +2105,7 @@ class Airflow(AirflowBaseView):
filter_dag_ids = allowed_dag_ids
if not filter_dag_ids:
- return wwwutils.json_response([])
+ return flask.json.jsonify([])
dags = (
session.query(DagRun.dag_id, sqla.func.count(DagRun.id))
@@ -2127,7 +2128,7 @@ class Airflow(AirflowBaseView):
'max_active_runs': max_active_runs,
}
)
- return wwwutils.json_response(payload)
+ return flask.json.jsonify(payload)
def _mark_dagrun_state_as_failed(self, dag_id, dag_run_id, confirmed):
if not dag_run_id:
@@ -3410,7 +3411,7 @@ class Airflow(AirflowBaseView):
for ti in dag.get_task_instances(dttm, dttm)
}
- return json.dumps(task_instances, cls=utils_json.AirflowJsonEncoder)
+ return flask.json.jsonify(task_instances)
@expose('/object/grid_data')
@auth.has_access(
@@ -3465,7 +3466,7 @@ class Airflow(AirflowBaseView):
}
# avoid spaces to reduce payload size
return (
- htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder),
+ htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps),
{'Content-Type': 'application/json; charset=utf-8'},
)
@@ -3508,7 +3509,7 @@ class Airflow(AirflowBaseView):
.all()
]
return (
- htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder),
+ htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps),
{'Content-Type': 'application/json; charset=utf-8'},
)
@@ -3545,7 +3546,7 @@ class Airflow(AirflowBaseView):
}
return (
- htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder),
+ htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps),
{'Content-Type': 'application/json; charset=utf-8'},
)
@@ -5205,7 +5206,7 @@ class AutocompleteView(AirflowBaseView):
query = unquote(request.args.get('query', ''))
if not query:
- return wwwutils.json_response([])
+ return flask.json.jsonify([])
# Provide suggestions of dag_ids and owners
dag_ids_query = session.query(
@@ -5239,7 +5240,7 @@ class AutocompleteView(AirflowBaseView):
payload = [
row._asdict() for row in dag_ids_query.union(owners_query).order_by('name').limit(10).all()
]
- return wwwutils.json_response(payload)
+ return flask.json.jsonify(payload)
class DagDependenciesView(AirflowBaseView):
diff --git a/tests/www/test_app.py b/tests/www/test_app.py
index e62bda71d0..d82dda1d7a 100644
--- a/tests/www/test_app.py
+++ b/tests/www/test_app.py
@@ -240,3 +240,13 @@ class TestFlaskCli:
output = capsys.readouterr()
assert "/login/" in output.out
+
+
+def test_app_can_json_serialize_k8s_pod():
+ # This is mostly testing that we have correctly configured the JSON provider to use. Testing the k8s pod
+ # is a side-effect of that.
+ k8s = pytest.importorskip('kubernetes.client.models')
+
+ pod = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]))
+ app = application.cached_app(testing=True)
+ assert app.json.dumps(pod) == '{"spec": {"containers": [{"name": "base"}]}}'