You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text version).
Posted to commits@tvm.apache.org by xi...@apache.org on 2022/07/07 19:44:24 UTC

[tvm] branch main updated: [MetaSchedule][Minor] Stability Improvements (#12014)

This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 013d5e8fcb [MetaSchedule][Minor] Stability Improvements (#12014)
013d5e8fcb is described below

commit 013d5e8fcbd94fb3a0c5c0cdcaea03af43c464aa
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Thu Jul 7 12:44:18 2022 -0700

    [MetaSchedule][Minor] Stability Improvements (#12014)
    
    * Fix tuning util for uint8.
    
    * Change to check runner_result.
    
    * Revert change to let cost model learn.
---
 python/tvm/meta_schedule/testing/tune_utils.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/python/tvm/meta_schedule/testing/tune_utils.py b/python/tvm/meta_schedule/testing/tune_utils.py
index aad8496a46..fe0984d51c 100644
--- a/python/tvm/meta_schedule/testing/tune_utils.py
+++ b/python/tvm/meta_schedule/testing/tune_utils.py
@@ -48,21 +48,21 @@ def generate_input_data(
     """
     if input_dtype.startswith("float"):
         return np.random.uniform(size=input_shape).astype(input_dtype)
-    if input_dtype in ["uint8", "int8"]:
-        return np.random.randint(
-            low=0,
-            high=127,
-            size=input_shape,
-            dtype="int32",  # TODO(zxybazh): fix the datatype when int8 / uint8 is supported better
+    if low is None or high is None:
+        warnings.warn(
+            f"Model input value range for shape {input_shape} of {input_dtype} is not set!"
         )
-    if input_dtype in ["int32", "int64"]:
-        if low is None or high is None:
-            warnings.warn(
-                "Model input value range for shape {input_shape} of {input_dtype} is not set!"
-            )
+    range_map = {
+        "uint8": (0, 255),
+        "int8": (-128, 127),
+        "int32": (0, 10000),
+        "int64": (0, 10000),
+    }
+    if input_dtype in range_map:
+        _low, _high = range_map[input_dtype]
         return np.random.randint(
-            low=0 if low is None else low,
-            high=10000 if high is None else high,
+            low=_low if low is None else low,
+            high=_high if high is None else high,
             size=input_shape,
             dtype=input_dtype,
         )