You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/02 06:59:46 UTC

[airflow] branch master updated: Add unit tests for mlengine_prediction_summary (#10022)

This is an automated email from the ASF dual-hosted git repository.

kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new ca3fa76  Add unit tests for mlengine_prediction_summary (#10022)
ca3fa76 is described below

commit ca3fa76b17fad8e05b85162120537a9ac00b6814
Author: Shekhar Singh <sh...@gmail.com>
AuthorDate: Sun Aug 2 12:29:07 2020 +0530

    Add unit tests for mlengine_prediction_summary (#10022)
---
 .../utils/test_mlengine_prediction_summary.py      | 95 ++++++++++++++++++++++
 tests/test_project_structure.py                    |  1 -
 2 files changed, 95 insertions(+), 1 deletion(-)

diff --git a/tests/providers/google/cloud/utils/test_mlengine_prediction_summary.py b/tests/providers/google/cloud/utils/test_mlengine_prediction_summary.py
new file mode 100644
index 0000000..ae9e0e7
--- /dev/null
+++ b/tests/providers/google/cloud/utils/test_mlengine_prediction_summary.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import base64
+import binascii
+import unittest
+
+import dill
+import mock
+
+try:
+    from airflow.providers.google.cloud.utils import mlengine_prediction_summary
+except ImportError as e:  # only a missing apache_beam becomes a skip; other errors propagate
+    skip = unittest.SkipTest(f"package apache_beam not present. Skipping all tests in {__name__}")
+    raise skip if 'apache_beam' in str(e) else e
+
+
+class TestJsonCode(unittest.TestCase):  # encode/decode checks for mlengine_prediction_summary.JsonCoder
+    def test_encode(self):  # a dict must encode to its JSON text, as bytes
+        self.assertEqual(mlengine_prediction_summary.JsonCoder.encode(dict(a=1)), b'{"a": 1}')
+
+    def test_decode(self):  # JSON text must decode back to the equivalent dict
+        self.assertEqual(mlengine_prediction_summary.JsonCoder.decode('{"a": 1}'), dict(a=1))
+
+
+class TestMakeSummary(unittest.TestCase):  # tests MakeSummary and the run() entry point
+    def test_make_summary(self):  # smoke test: only checks that the call does not raise
+        mlengine_prediction_summary.MakeSummary(1, lambda x: x, [])
+
+    def test_run_without_all_arguments_should_raise_exception(self):
+        with self.assertRaises(SystemExit):  # no list given: presumably reads sys.argv -- confirm
+            mlengine_prediction_summary.run()
+
+        with self.assertRaises(SystemExit):  # --metric_fn_encoded and --metric_keys missing
+            mlengine_prediction_summary.run([
+                "--prediction_path=some/path",
+            ])
+
+        with self.assertRaises(SystemExit):  # --metric_keys missing
+            mlengine_prediction_summary.run([
+                "--prediction_path=some/path",
+                "--metric_fn_encoded=encoded_text",
+            ])
+
+    def test_run_should_fail_for_invalid_encoded_fn(self):
+        with self.assertRaises(binascii.Error):  # the value is not decodable base64
+            mlengine_prediction_summary.run([
+                "--prediction_path=some/path",
+                "--metric_fn_encoded=invalid_encoded_text",
+                "--metric_keys=a",
+            ])
+
+    def test_run_should_fail_if_enc_fn_is_not_callable(self):
+        non_callable_value = 1  # dill-pickled and base64-encoded just like a real metric fn
+        fn_enc = base64.b64encode(dill.dumps(non_callable_value)).decode('utf-8')
+
+        with self.assertRaises(ValueError):  # decodes fine, but the result is not callable
+            mlengine_prediction_summary.run([
+                "--prediction_path=some/path",
+                "--metric_fn_encoded=" + fn_enc,
+                "--metric_keys=a",
+            ])
+
+    @mock.patch.object(mlengine_prediction_summary.beam.pipeline, "PipelineOptions")
+    @mock.patch.object(mlengine_prediction_summary.beam, "Pipeline")
+    @mock.patch.object(mlengine_prediction_summary.beam.io, "ReadFromText")
+    def test_run_should_not_fail_with_valid_fn(self, io_mock, pipeline_mock, pipeline_options_mock):
+        def metric_function():  # callable, so run()'s not-callable check does not fire
+            return 1
+
+        fn_enc = base64.b64encode(dill.dumps(metric_function)).decode('utf-8')
+
+        mlengine_prediction_summary.run([
+            "--prediction_path=some/path",
+            "--metric_fn_encoded=" + fn_enc,
+            "--metric_keys=a",
+        ])
+        # mock args follow the decorators bottom-up: ReadFromText, Pipeline, PipelineOptions
+        pipeline_options_mock.assert_called_once_with([])  # presumably the leftover, unparsed CLI args
+        pipeline_mock.assert_called_once()
+        io_mock.assert_called_once()
diff --git a/tests/test_project_structure.py b/tests/test_project_structure.py
index 621ffea..120a018 100644
--- a/tests/test_project_structure.py
+++ b/tests/test_project_structure.py
@@ -30,7 +30,6 @@ ROOT_FOLDER = os.path.realpath(
 MISSING_TEST_FILES = {
     'tests/providers/google/cloud/log/test_gcs_task_handler.py',
     'tests/providers/google/cloud/operators/test_datastore.py',
-    'tests/providers/google/cloud/utils/test_mlengine_prediction_summary.py',
     'tests/providers/microsoft/azure/sensors/test_azure_cosmos.py',
     'tests/providers/microsoft/azure/log/test_wasb_task_handler.py',
 }