You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/12/04 07:44:16 UTC

[tvm] branch main updated: [microNPU] Fix bug with re-reading in EncodeConstants (#9646)

This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new befa7f3  [microNPU] Fix bug with re-reading in EncodeConstants (#9646)
befa7f3 is described below

commit befa7f33587068b6bc2ee7f57dc726d7bb1dc365
Author: Matthew Barrett <55...@users.noreply.github.com>
AuthorDate: Sat Dec 4 07:43:56 2021 +0000

    [microNPU] Fix bug with re-reading in EncodeConstants (#9646)
    
    When a striping strategy that leads to weights
    being re-read was deployed, the logic in EncodeConstants
    failed. This adds a test for that case and fixes the
    pass so that it handles the case correctly.
    
    Change-Id: I6f54cdb7be69428e49c3b4208271cd3e6c192e5d
---
 .../tvm/relay/backend/contrib/ethosu/tir/passes.py | 10 +++-
 .../contrib/test_ethosu/test_encode_constants.py   | 66 ++++++++++++++++++++++
 2 files changed, 73 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index 41a6832..0a6dcd1 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -331,11 +331,15 @@ def EncodeConstants(const_dict):
     def _new_buffer(old_buffer, new_value):
         """Create a new buffer and add the old buffer and its pointer to the
         rewriting maps."""
-        new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype))
-        pointer_to_buffer[new_buffer.data] = new_buffer
+        if old_buffer in rewrite_buffer:
+            new_buffer = rewrite_buffer[old_buffer]
+        else:
+            new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype))
+            pointer_to_buffer[new_buffer.data] = new_buffer
+            buffer_to_const[new_buffer] = new_value
+
         rewrite_buffer[old_buffer] = new_buffer
         rewrite_pointer[old_buffer.data] = new_buffer.data
-        buffer_to_const[new_buffer] = new_value
 
     def _visit_encode_pre(stmt):
         if isinstance(stmt, tvm.tir.Call):
diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py
index de8a7f9..7f5eeb1 100644
--- a/tests/python/contrib/test_ethosu/test_encode_constants.py
+++ b/tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -110,6 +110,72 @@ def test_weight_stream_only():
 
 # fmt: off
 @tvm.script.ir_module
+class RereadWeights:
+    @T.prim_func
+    def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+        placeholder_3 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8")
+        buffer = T.match_buffer(placeholder_1, [304], dtype="uint8")
+        buffer_1 = T.match_buffer(placeholder_2, [80], dtype="uint8")
+        ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8")
+        # body
+        placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True})
+        placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True})
+        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder_3.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write_1.data, 64), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+def test_re_read_weights():
+    def _cascader(cached_func, const_dict, sch):
+        weights = cached_func.inputs[1]
+        bias = cached_func.inputs[2]
+        out = cached_func.outputs[0]
+        conv_compute = Convolution2DCompute.from_output(out)
+        co = conv_compute.split(sch, 2, 8)
+        cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d])
+        cache_bias = sch.cache_read(bias, "global", [conv_compute.conv2d])
+        sch[cache_weights].compute_at(sch[out], co)
+        sch[cache_bias].compute_at(sch[out], co)
+
+    def _get_func():
+        ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8")
+        conv = make_ethosu_conv2d(
+            ifm,
+            32,
+            8,
+            (1, 1),
+            (0, 0),
+            (1, 1),
+            (1, 1),
+        )
+        func = relay.Function(relay.analysis.free_vars(conv), conv)
+        func = run_opt_pass(func, relay.transform.InferType())
+        return func
+
+    func = _get_func()
+    mod, consts = lower_to_tir(func, cascader=_cascader)
+    script = mod.script(show_meta=True)
+    test_mod = tvm.script.from_source(script)
+    reference_mod = RereadWeights
+    tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
+
+    reference_const_sizes = {1: 304, 2: 80}
+    test_const_sizes = {}
+    for key, value in consts.items():
+        test_const_sizes[key] = len(value)
+
+    assert reference_const_sizes == test_const_sizes
+
+
+# fmt: off
+@tvm.script.ir_module
 class DirectReadOnly:
     @T.prim_func
     def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_write: T.handle) -> None: