You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/06/24 15:49:22 UTC
[incubator-tvm] branch master updated: Fix serialization of inf float value (#5912)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new fcaba98 Fix serialization of inf float value (#5912)
fcaba98 is described below
commit fcaba9892775e817148c1de36d917a4d5566a003
Author: lixiaoquan <ra...@163.com>
AuthorDate: Wed Jun 24 23:49:13 2020 +0800
Fix serialization of inf float value (#5912)
---
src/node/serialization.cc | 12 +++++++++---
tests/python/unittest/test_node_reflection.py | 11 +++++++++++
2 files changed, 20 insertions(+), 3 deletions(-)
diff --git a/src/node/serialization.cc b/src/node/serialization.cc
index 4382579..42767c2 100644
--- a/src/node/serialization.cc
+++ b/src/node/serialization.cc
@@ -352,9 +352,15 @@ class JSONAttrSetter : public AttrVisitor {
template <typename T>
void ParseValue(const char* key, T* value) const {
std::istringstream is(GetValue(key));
- is >> *value;
- if (is.fail()) {
- LOG(FATAL) << "Wrong value format for field " << key;
+ if (is.str() == "inf") {
+ *value = std::numeric_limits<T>::infinity();
+ } else if (is.str() == "-inf") {
+ *value = -std::numeric_limits<T>::infinity();
+ } else {
+ is >> *value;
+ if (is.fail()) {
+ LOG(FATAL) << "Wrong value format for field " << key;
+ }
}
}
void Visit(const char* key, double* value) final { ParseValue(key, value); }
diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py
index 3a7318c..d375fa0 100644
--- a/tests/python/unittest/test_node_reflection.py
+++ b/tests/python/unittest/test_node_reflection.py
@@ -28,6 +28,16 @@ def test_const_saveload_json():
zz = tvm.ir.load_json(json_str)
tvm.ir.assert_structural_equal(zz, z, map_free_vars=True)
+def _test_infinity_value(value, dtype):
+ x = tvm.tir.const(value, dtype)
+ json_str = tvm.ir.save_json(x)
+ tvm.ir.assert_structural_equal(x, tvm.ir.load_json(json_str))
+
+def test_infinity_value():
+ _test_infinity_value(float("inf"), 'float64')
+ _test_infinity_value(float("-inf"), 'float64')
+ _test_infinity_value(float("inf"), 'float32')
+ _test_infinity_value(float("-inf"), 'float32')
def test_make_smap():
# save load json
@@ -145,3 +155,4 @@ if __name__ == "__main__":
test_make_sum()
test_pass_config()
test_dict()
+ test_infinity_value()