Posted to commits@tvm.apache.org by ju...@apache.org on 2022/12/18 01:45:01 UTC

[tvm] branch main updated: [BugFix][TVMScript] Parser crash (#13630)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4096548d13 [BugFix][TVMScript] Parser crash (#13630)
4096548d13 is described below

commit 4096548d13cc8add8fe1f89d54f0968f89570461
Author: lightzhan <11...@qq.com>
AuthorDate: Sun Dec 18 09:44:49 2022 +0800

    [BugFix][TVMScript] Parser crash (#13630)
    
    This PR fixes a parser crash that occurs when the old value of a variable is a numpy array but the new value is not. For example:
    
    ```python
    import numpy as np
    from tvm.script import tir as T
    def func_wrapper(shape, dtype):
        @T.prim_func
        def test_case():
            a = T.alloc_buffer(shape, dtype=dtype)
    
        return test_case
    
    
    if __name__ == "__main__":
        a = np.zeros((10, 10), dtype="int8")
        print(func_wrapper((256, 256), dtype="int8").script())
    ```
    
    In the above code, there are two assignments to the variable 'a'. In the global scope its value is a numpy array, but inside the prim function it is a Buffer. The parser tracks the values of variables such as 'a' in a table named 'name2value'.
    Before the parser updates a variable's value, it compares the new value with the old one and skips the update if they are equal. This is where the problem arises: comparing a numpy array with another value using '==' yields an element-wise boolean array rather than a single bool, and such an array cannot be used directly as the condition of an if statement. As a result, the code above fails with the following error:
    
    ```shell
    error: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
     --> /workspace/code_newest/tvm/private_test/test_meta_programming.py:16:9
        |
     16 |          a = T.alloc_buffer(shape, dtype=dtype)
        |          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ```
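
    The failure can be reproduced with NumPy alone. In the sketch below, `FakeBuffer` is a hypothetical stand-in for the Buffer returned by `T.alloc_buffer` (it is not a TVM class). As in the real case, `==` between the array and the stand-in is evaluated element-wise, and using the result as an `if` condition raises the same `ValueError`.

    ```python
    import numpy as np


    class FakeBuffer:
        """Hypothetical stand-in for the Buffer created inside the prim_func."""


    # Value bound to `a` at module scope in the reproducer above.
    old_value = np.zeros((10, 10), dtype="int8")
    new_value = FakeBuffer()

    # ndarray.__eq__ compares element-wise, so this is a 10x10 bool array,
    # not a single bool.
    comparison = old_value == new_value

    try:
        # Mirrors the unpatched check `if ... and old == value: return`.
        if comparison:
            pass
        ambiguous = False
    except ValueError as err:
        ambiguous = True
        message = str(err)

    print(comparison.shape, ambiguous)
    ```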
    
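    This PR fixes this by first checking that the old value is an instance of the new value's type, so that a numpy array is never compared with a Buffer, and by reducing the comparison with `.all()` when the new value is a numpy array.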
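
    The comparison logic of the fix can be sketched as a standalone predicate. `same_binding` and `FakeBuffer` are hypothetical names, not part of the TVM API, and this is an illustration of the idea rather than the actual `VarTable` code: like the patch below, it collapses an array comparison into a single bool with `.all()` and requires compatible types before comparing other values.

    ```python
    from typing import Any

    import numpy as np


    def same_binding(old: Any, new: Any) -> bool:
        """Hypothetical helper: True if `new` may be treated as the same binding as `old`."""
        old_is_array = isinstance(old, np.ndarray)
        new_is_array = isinstance(new, np.ndarray)
        if old_is_array or new_is_array:
            # Only two arrays of the same shape can be equal; `.all()` collapses
            # the element-wise result of `==` into one bool.
            return (
                old_is_array
                and new_is_array
                and old.shape == new.shape
                and bool((old == new).all())
            )
        # Non-array values: compare only when the types are compatible,
        # mirroring the patch's isinstance(old, type(new)) guard.
        return isinstance(old, type(new)) and bool(old == new)


    class FakeBuffer:
        """Hypothetical stand-in for a TVM Buffer."""


    zeros = np.zeros((10, 10), dtype="int8")
    print(same_binding(zeros, zeros.copy()))  # equal arrays: the update can be skipped
    print(same_binding(zeros, FakeBuffer()))  # array vs. non-array: never equal
    ```

    Handling arrays in their own branch means that two unequal or differently shaped arrays yield `False` rather than reaching a bare `==` in a boolean context.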
    
    Co-authored-by: lightzhan-intellif <zh...@intellif.com>
---
 python/tvm/script/parser/core/parser.py            |  8 ++++++--
 tests/python/unittest/test_tvmscript_regression.py | 15 +++++++++++++++
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py
index c6d43f11cb..7c699c42ae 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -19,6 +19,7 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Set, Union
+import numpy as np
 from tvm._ffi.base import TVMError
 
 from tvm.error import DiagnosticError
@@ -150,8 +151,11 @@ class VarTable:
             The options of whether variable shadowing allwed for this variable.
         """
         # Skip if the key and value are equal to those in the var_table
-        if self.name2value[var] and self.name2value[var][-1] == value:
-            return
+        if self.name2value[var] and isinstance(self.name2value[var][-1], type(value)):
+            if isinstance(value, np.ndarray) and (self.name2value[var][-1] == value).all():
+                return
+            elif self.name2value[var][-1] == value:
+                return
         if allow_shadowing and var in self.frames[-1].vars:
             # Shadowing
             self.name2value[var][-1] = value
diff --git a/tests/python/unittest/test_tvmscript_regression.py b/tests/python/unittest/test_tvmscript_regression.py
index 3ad8090893..05c1665ea2 100644
--- a/tests/python/unittest/test_tvmscript_regression.py
+++ b/tests/python/unittest/test_tvmscript_regression.py
@@ -45,5 +45,20 @@ def test_multi_element_array_in_outmost_namespace():
     tvm.ir.assert_structural_equal(func, rt_func)
 
 
+def test_different_dtype_assignment_to_var():
+    @T.prim_func
+    def test_case():
+        a = T.alloc_buffer((10, 10), dtype="int8")
+
+    @T.prim_func
+    def func_ref():
+        a = T.alloc_buffer([10, 10], dtype="int8")
+        T.evaluate(0)
+
+    tvm.ir.assert_structural_equal(test_case, func_ref)
+
+
 if __name__ == "__main__":
+    a = numpy.zeros((10, 10), dtype="int8")
     test_multi_element_array_in_outmost_namespace()
+    test_different_dtype_assignment_to_var()