You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2018/08/03 10:08:14 UTC
[incubator-mxnet] branch master updated: Added default tolerance levels for regression checks for MBCC (#12006)
This is an automated email from the ASF dual-hosted git repository.
marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 2534164 Added default tolerance levels for regression checks for MBCC (#12006)
2534164 is described below
commit 25341648365598a9a123f033bf92ce7fb51c0a39
Author: Piyush Ghai <gh...@osu.edu>
AuthorDate: Fri Aug 3 03:08:05 2018 -0700
Added default tolerance levels for regression checks for MBCC (#12006)
* Added default tolerance levels (rtol and atol) for assert_almost_equal in the model backwards compatibility check (MBCC)
* Nudge to CI
---
tests/nightly/model_backwards_compatibility_check/common.py | 2 ++
.../model_backwards_compat_inference.py | 8 ++++----
2 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/tests/nightly/model_backwards_compatibility_check/common.py b/tests/nightly/model_backwards_compatibility_check/common.py
index 4c61cc4..8950a92 100644
--- a/tests/nightly/model_backwards_compatibility_check/common.py
+++ b/tests/nightly/model_backwards_compatibility_check/common.py
@@ -41,6 +41,8 @@ data_folder = 'mxnet-model-backwards-compatibility-data'
backslash = '/'
s3 = boto3.resource('s3')
ctx = mx.cpu(0)
+atol_default = 1e-5
+rtol_default = 1e-5
def get_model_path(model_name):
diff --git a/tests/nightly/model_backwards_compatibility_check/model_backwards_compat_inference.py b/tests/nightly/model_backwards_compatibility_check/model_backwards_compat_inference.py
index ae368e3..5d63e7e 100644
--- a/tests/nightly/model_backwards_compatibility_check/model_backwards_compat_inference.py
+++ b/tests/nightly/model_backwards_compatibility_check/model_backwards_compat_inference.py
@@ -44,7 +44,7 @@ def test_module_checkpoint_api():
old_inference_results = load_inference_results(model_name)
inference_results = loaded_model.predict(data_iter)
# Check whether they are equal or not ?
- assert_almost_equal(inference_results.asnumpy(), old_inference_results.asnumpy())
+ assert_almost_equal(inference_results.asnumpy(), old_inference_results.asnumpy(), rtol=rtol_default, atol=atol_default)
clean_model_files(model_files, model_name)
logging.info('=================================')
@@ -69,7 +69,7 @@ def test_lenet_gluon_load_params_api():
loaded_model.load_params(model_name + '-params')
output = loaded_model(test_data)
old_inference_results = mx.nd.load(model_name + '-inference')['inference']
- assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy())
+ assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy(), rtol=rtol_default, atol=atol_default)
clean_model_files(model_files, model_name)
logging.info('=================================')
logging.info('Assertion passed for model : %s' % model_name)
@@ -92,7 +92,7 @@ def test_lenet_gluon_hybrid_imports_api():
loaded_model = gluon.SymbolBlock.imports(model_name + '-symbol.json', ['data'], model_name + '-0000.params')
output = loaded_model(test_data)
old_inference_results = mx.nd.load(model_name + '-inference')['inference']
- assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy())
+ assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy(), rtol=rtol_default, atol=atol_default)
clean_model_files(model_files, model_name)
logging.info('=================================')
logging.info('Assertion passed for model : %s' % model_name)
@@ -124,7 +124,7 @@ def test_lstm_gluon_load_parameters_api():
loaded_model.load_parameters(model_name + '-params')
output = loaded_model(test_data)
old_inference_results = mx.nd.load(model_name + '-inference')['inference']
- assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy())
+ assert_almost_equal(old_inference_results.asnumpy(), output.asnumpy(), rtol=rtol_default, atol=atol_default)
clean_model_files(model_files, model_name)
logging.info('=================================')
logging.info('Assertion passed for model : %s' % model_name)