You are viewing a plain-text version of this content; the canonical link to it (originally attached to the word "here") was lost in the plain-text conversion.
Posted to commits@airflow.apache.org by bo...@apache.org on 2018/01/11 19:36:00 UTC

incubator-airflow git commit: [AIRFLOW-1949] Fix var upload, str() produces "b'...'" which is not json

Repository: incubator-airflow
Updated Branches:
  refs/heads/master d9bbb6312 -> 1f3b60792


[AIRFLOW-1949] Fix var upload, str() produces "b'...'" which is not json

Closes #2899 from j16r/bugfix/AIRFLOW-1949_config_upload


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/1f3b6079
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/1f3b6079
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/1f3b6079

Branch: refs/heads/master
Commit: 1f3b607925eaacdb11831f258069623054c40eae
Parents: d9bbb63
Author: John Barker <je...@gmail.com>
Authored: Thu Jan 11 20:35:55 2018 +0100
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Thu Jan 11 20:35:55 2018 +0100

----------------------------------------------------------------------
 airflow/www/views.py    | 10 ++++++----
 tests/www/test_views.py | 46 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 52 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1f3b6079/airflow/www/views.py
----------------------------------------------------------------------
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 6bcb66d..716e9fe 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -28,6 +28,7 @@ import math
 import json
 import bleach
 import pendulum
+import codecs
 from collections import defaultdict
 
 import inspect
@@ -89,6 +90,8 @@ from airflow.www.validators import GreaterEqualThan
 QUERY_LIMIT = 100000
 CHART_LIMIT = 200000
 
+UTF8_READER = codecs.getreader('utf-8')
+
 dagbag = models.DagBag(settings.DAGS_FOLDER)
 
 login_required = airflow.login.login_required
@@ -1790,10 +1793,9 @@ class Airflow(BaseView):
     @wwwutils.action_logging
     def varimport(self):
         try:
-            out = str(request.files['file'].read())
-            d = json.loads(out)
-        except Exception:
-            flash("Missing file or syntax error.")
+            d = json.load(UTF8_READER(request.files['file']))
+        except Exception as e:
+            flash("Missing file or syntax error: {}.".format(e))
         else:
             for k, v in d.items():
                 models.Variable.set(k, v, serialize_json=isinstance(v, dict))

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1f3b6079/tests/www/test_views.py
----------------------------------------------------------------------
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index 017176d..6ea8db2 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import io
 import copy
 import logging.config
 import os
@@ -378,5 +379,50 @@ class TestLogView(unittest.TestCase):
                       response.data.decode('utf-8'))
 
 
+class TestVarImportView(unittest.TestCase):
+
+    IMPORT_ENDPOINT = '/admin/airflow/varimport'
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestVarImportView, cls).setUpClass()
+        session = Session()
+        session.query(models.User).delete()
+        session.commit()
+        user = models.User(username='airflow')
+        session.add(user)
+        session.commit()
+        session.close()
+
+    def setUp(self):
+        super(TestVarImportView, self).setUp()
+        configuration.load_test_config()
+        app = application.create_app(testing=True)
+        app.config['WTF_CSRF_METHODS'] = []
+        self.app = app.test_client()
+
+    def tearDown(self):
+        super(TestVarImportView, self).tearDown()
+
+    @classmethod
+    def tearDownClass(cls):
+        session = Session()
+        session.query(models.User).delete()
+        session.commit()
+        session.close()
+        super(TestVarImportView, cls).tearDownClass()
+
+    def test_import_variables(self):
+        response = self.app.post(
+            self.IMPORT_ENDPOINT,
+            data={'file': (io.BytesIO(b'{"KEY": "VALUE"}'), 'test.json')},
+            follow_redirects=True
+        )
+        self.assertEqual(response.status_code, 200)
+        body = response.data.decode('utf-8')
+        self.assertIn('KEY', body)
+        self.assertIn('VALUE', body)
+
+
 if __name__ == '__main__':
     unittest.main()