You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by je...@apache.org on 2022/09/23 03:19:33 UTC

[airflow] 12/13: Correctly set json_provider_class on Flask app so it uses our encoder (#26554)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-4-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 5ae2ae531bba1d15975ba501c13f130959e8dd38
Author: Ash Berlin-Taylor <as...@apache.org>
AuthorDate: Wed Sep 21 22:12:39 2022 +0100

    Correctly set json_provider_class on Flask app so it uses our encoder (#26554)
    
    Setting `json_provider_class` after creating the app had no effect: it
    turns out `Flask()` itself sets `self.json = self.json_provider_class(self)`,
    so we were setting it too late.
    
    (cherry picked from commit 378dfbe2fe266f17859dbabd34b9bc8cd5c904ab)
---
 airflow/utils/json.py | 20 ++++++++++++++++++--
 airflow/www/app.py    |  5 +++--
 airflow/www/utils.py  | 10 +---------
 airflow/www/views.py  | 33 +++++++++++++++++----------------
 tests/www/test_app.py | 10 ++++++++++
 5 files changed, 49 insertions(+), 29 deletions(-)

diff --git a/airflow/utils/json.py b/airflow/utils/json.py
index fcc4eedd6e..ff11097824 100644
--- a/airflow/utils/json.py
+++ b/airflow/utils/json.py
@@ -17,11 +17,12 @@
 # under the License.
 from __future__ import annotations
 
+import json
 import logging
 from datetime import date, datetime
 from decimal import Decimal
 
-from flask.json import JSONEncoder
+from flask.json.provider import JSONProvider
 
 from airflow.utils.timezone import convert_to_utc, is_naive
 
@@ -40,7 +41,7 @@ except ImportError:
 log = logging.getLogger(__name__)
 
 
-class AirflowJsonEncoder(JSONEncoder):
+class AirflowJsonEncoder(json.JSONEncoder):
     """Custom Airflow json encoder implementation."""
 
     def __init__(self, *args, **kwargs):
@@ -107,3 +108,18 @@ class AirflowJsonEncoder(JSONEncoder):
                 return {}
 
         raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
+
+
+class AirflowJsonProvider(JSONProvider):
+    """JSON Provider for Flask app to use AirflowJsonEncoder."""
+
+    ensure_ascii: bool = True
+    sort_keys: bool = True
+
+    def dumps(self, obj, **kwargs):
+        kwargs.setdefault('ensure_ascii', self.ensure_ascii)
+        kwargs.setdefault('sort_keys', self.sort_keys)
+        return json.dumps(obj, **kwargs, cls=AirflowJsonEncoder)
+
+    def loads(self, s: str | bytes, **kwargs):
+        return json.loads(s, **kwargs)
diff --git a/airflow/www/app.py b/airflow/www/app.py
index d40f3badb8..b67314c99a 100644
--- a/airflow/www/app.py
+++ b/airflow/www/app.py
@@ -32,7 +32,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowConfigException, RemovedInAirflow3Warning
 from airflow.logging_config import configure_logging
 from airflow.models import import_all_models
-from airflow.utils.json import AirflowJsonEncoder
+from airflow.utils.json import AirflowJsonProvider
 from airflow.www.extensions.init_appbuilder import init_appbuilder
 from airflow.www.extensions.init_appbuilder_links import init_appbuilder_links
 from airflow.www.extensions.init_dagbag import init_dagbag
@@ -109,7 +109,8 @@ def create_app(config=None, testing=False):
         flask_app.config['SQLALCHEMY_ENGINE_OPTIONS'] = settings.prepare_engine_args()
 
     # Configure the JSON encoder used by `|tojson` filter from Flask
-    flask_app.json_provider_class = AirflowJsonEncoder
+    flask_app.json_provider_class = AirflowJsonProvider
+    flask_app.json = AirflowJsonProvider(flask_app)
 
     csrf.init_app(flask_app)
 
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 596fc4218e..3429d6a140 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -24,7 +24,7 @@ from typing import Any
 from urllib.parse import urlencode
 
 import sqlalchemy as sqla
-from flask import Response, request, url_for
+from flask import request, url_for
 from flask.helpers import flash
 from flask_appbuilder.forms import FieldConverter
 from flask_appbuilder.models.filters import BaseFilter
@@ -47,7 +47,6 @@ from airflow.models.taskinstance import TaskInstance
 from airflow.utils import timezone
 from airflow.utils.code_utils import get_python_source
 from airflow.utils.helpers import alchemy_to_dict
-from airflow.utils.json import AirflowJsonEncoder
 from airflow.utils.state import State, TaskInstanceState
 from airflow.www.forms import DateTimeWithTimezoneField
 from airflow.www.widgets import AirflowDateTimePickerWidget
@@ -322,13 +321,6 @@ def epoch(dttm):
     return (int(time.mktime(dttm.timetuple())) * 1000,)
 
 
-def json_response(obj):
-    """Returns a json response from a json serializable python object"""
-    return Response(
-        response=json.dumps(obj, indent=4, cls=AirflowJsonEncoder), status=200, mimetype="application/json"
-    )
-
-
 def make_cache_key(*args, **kwargs):
     """Used by cache to get a unique key per URL"""
     path = request.path
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 5d609d044b..65e825510d 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -37,6 +37,7 @@ from typing import Any, Callable
 from urllib.parse import parse_qsl, unquote, urlencode, urlparse
 
 import configupdater
+import flask.json
 import lazy_object_proxy
 import markupsafe
 import nvd3
@@ -107,7 +108,7 @@ from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.dependencies_deps import RUNNING_DEPS, SCHEDULER_QUEUED_DEPS
 from airflow.timetables.base import DataInterval, TimeRestriction
 from airflow.timetables.interval import CronDataIntervalTimetable
-from airflow.utils import json as utils_json, timezone, yaml
+from airflow.utils import timezone, yaml
 from airflow.utils.airflow_flask_app import get_airflow_app
 from airflow.utils.dag_edges import dag_edges
 from airflow.utils.dates import infer_time_unit, scale_time_units
@@ -575,7 +576,7 @@ class Airflow(AirflowBaseView):
             'latest_scheduler_heartbeat': latest_scheduler_heartbeat,
         }
 
-        return wwwutils.json_response(payload)
+        return flask.json.jsonify(payload)
 
     @expose('/home')
     @auth.has_access(
@@ -856,7 +857,7 @@ class Airflow(AirflowBaseView):
             filter_dag_ids = allowed_dag_ids
 
         if not filter_dag_ids:
-            return wwwutils.json_response({})
+            return flask.json.jsonify({})
 
         payload = {}
         dag_state_stats = dag_state_stats.filter(dr.dag_id.in_(filter_dag_ids))
@@ -873,7 +874,7 @@ class Airflow(AirflowBaseView):
                 count = data.get(dag_id, {}).get(state, 0)
                 payload[dag_id].append({'state': state, 'count': count})
 
-        return wwwutils.json_response(payload)
+        return flask.json.jsonify(payload)
 
     @expose('/task_stats', methods=['POST'])
     @auth.has_access(
@@ -889,7 +890,7 @@ class Airflow(AirflowBaseView):
         allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
 
         if not allowed_dag_ids:
-            return wwwutils.json_response({})
+            return flask.json.jsonify({})
 
         # Filter by post parameters
         selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist('dag_ids') if dag_id}
@@ -983,7 +984,7 @@ class Airflow(AirflowBaseView):
             for state in State.task_states:
                 count = data.get(dag_id, {}).get(state, 0)
                 payload[dag_id].append({'state': state, 'count': count})
-        return wwwutils.json_response(payload)
+        return flask.json.jsonify(payload)
 
     @expose('/last_dagruns', methods=['POST'])
     @auth.has_access(
@@ -1006,7 +1007,7 @@ class Airflow(AirflowBaseView):
             filter_dag_ids = allowed_dag_ids
 
         if not filter_dag_ids:
-            return wwwutils.json_response({})
+            return flask.json.jsonify({})
 
         last_runs_subquery = (
             session.query(
@@ -1046,7 +1047,7 @@ class Airflow(AirflowBaseView):
             }
             for r in query
         }
-        return wwwutils.json_response(resp)
+        return flask.json.jsonify(resp)
 
     @expose('/code')
     @auth.has_access(
@@ -2104,7 +2105,7 @@ class Airflow(AirflowBaseView):
             filter_dag_ids = allowed_dag_ids
 
         if not filter_dag_ids:
-            return wwwutils.json_response([])
+            return flask.json.jsonify([])
 
         dags = (
             session.query(DagRun.dag_id, sqla.func.count(DagRun.id))
@@ -2127,7 +2128,7 @@ class Airflow(AirflowBaseView):
                     'max_active_runs': max_active_runs,
                 }
             )
-        return wwwutils.json_response(payload)
+        return flask.json.jsonify(payload)
 
     def _mark_dagrun_state_as_failed(self, dag_id, dag_run_id, confirmed):
         if not dag_run_id:
@@ -3410,7 +3411,7 @@ class Airflow(AirflowBaseView):
                 for ti in dag.get_task_instances(dttm, dttm)
             }
 
-        return json.dumps(task_instances, cls=utils_json.AirflowJsonEncoder)
+        return flask.json.jsonify(task_instances)
 
     @expose('/object/grid_data')
     @auth.has_access(
@@ -3465,7 +3466,7 @@ class Airflow(AirflowBaseView):
             }
         # avoid spaces to reduce payload size
         return (
-            htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder),
+            htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps),
             {'Content-Type': 'application/json; charset=utf-8'},
         )
 
@@ -3508,7 +3509,7 @@ class Airflow(AirflowBaseView):
                 .all()
             ]
         return (
-            htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder),
+            htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps),
             {'Content-Type': 'application/json; charset=utf-8'},
         )
 
@@ -3545,7 +3546,7 @@ class Airflow(AirflowBaseView):
         }
 
         return (
-            htmlsafe_json_dumps(data, separators=(',', ':'), cls=utils_json.AirflowJsonEncoder),
+            htmlsafe_json_dumps(data, separators=(',', ':'), dumps=flask.json.dumps),
             {'Content-Type': 'application/json; charset=utf-8'},
         )
 
@@ -5205,7 +5206,7 @@ class AutocompleteView(AirflowBaseView):
         query = unquote(request.args.get('query', ''))
 
         if not query:
-            return wwwutils.json_response([])
+            return flask.json.jsonify([])
 
         # Provide suggestions of dag_ids and owners
         dag_ids_query = session.query(
@@ -5239,7 +5240,7 @@ class AutocompleteView(AirflowBaseView):
         payload = [
             row._asdict() for row in dag_ids_query.union(owners_query).order_by('name').limit(10).all()
         ]
-        return wwwutils.json_response(payload)
+        return flask.json.jsonify(payload)
 
 
 class DagDependenciesView(AirflowBaseView):
diff --git a/tests/www/test_app.py b/tests/www/test_app.py
index e62bda71d0..d82dda1d7a 100644
--- a/tests/www/test_app.py
+++ b/tests/www/test_app.py
@@ -240,3 +240,13 @@ class TestFlaskCli:
 
         output = capsys.readouterr()
         assert "/login/" in output.out
+
+
+def test_app_can_json_serialize_k8s_pod():
+    # This is mostly testing that we have correctly configured the JSON provider to use. Testing the k8s pod
+    # is a side-effect of that.
+    k8s = pytest.importorskip('kubernetes.client.models')
+
+    pod = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]))
+    app = application.cached_app(testing=True)
+    assert app.json.dumps(pod) == '{"spec": {"containers": [{"name": "base"}]}}'