You are viewing a plain text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/06/24 15:49:22 UTC

[incubator-tvm] branch master updated: Fix serialization of inf float value (#5912)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new fcaba98  Fix serialization of inf float value (#5912)
fcaba98 is described below

commit fcaba9892775e817148c1de36d917a4d5566a003
Author: lixiaoquan <ra...@163.com>
AuthorDate: Wed Jun 24 23:49:13 2020 +0800

    Fix serialization of inf float value (#5912)
---
 src/node/serialization.cc                     | 12 +++++++++---
 tests/python/unittest/test_node_reflection.py | 11 +++++++++++
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/src/node/serialization.cc b/src/node/serialization.cc
index 4382579..42767c2 100644
--- a/src/node/serialization.cc
+++ b/src/node/serialization.cc
@@ -352,9 +352,15 @@ class JSONAttrSetter : public AttrVisitor {
   template <typename T>
   void ParseValue(const char* key, T* value) const {
     std::istringstream is(GetValue(key));
-    is >> *value;
-    if (is.fail()) {
-      LOG(FATAL) << "Wrong value format for field " << key;
+    if (is.str() == "inf") {
+      *value = std::numeric_limits<T>::infinity();
+    } else if (is.str() == "-inf") {
+      *value = -std::numeric_limits<T>::infinity();
+    } else {
+      is >> *value;
+      if (is.fail()) {
+        LOG(FATAL) << "Wrong value format for field " << key;
+      }
     }
   }
   void Visit(const char* key, double* value) final { ParseValue(key, value); }
diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py
index 3a7318c..d375fa0 100644
--- a/tests/python/unittest/test_node_reflection.py
+++ b/tests/python/unittest/test_node_reflection.py
@@ -28,6 +28,16 @@ def test_const_saveload_json():
     zz = tvm.ir.load_json(json_str)
     tvm.ir.assert_structural_equal(zz, z, map_free_vars=True)
 
+def _test_infinity_value(value, dtype):
+    x = tvm.tir.const(value, dtype)
+    json_str = tvm.ir.save_json(x)
+    tvm.ir.assert_structural_equal(x, tvm.ir.load_json(json_str))
+
+def test_infinity_value():
+    _test_infinity_value(float("inf"), 'float64')
+    _test_infinity_value(float("-inf"), 'float64')
+    _test_infinity_value(float("inf"), 'float32')
+    _test_infinity_value(float("-inf"), 'float32')
 
 def test_make_smap():
     # save load json
@@ -145,3 +155,4 @@ if __name__ == "__main__":
     test_make_sum()
     test_pass_config()
     test_dict()
+    test_infinity_value()