You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/12/04 07:44:16 UTC
[tvm] branch main updated: [microNPU] Fix bug with re-reading in EncodeConstants (#9646)
This is an automated email from the ASF dual-hosted git repository.
manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new befa7f3 [microNPU] Fix bug with re-reading in EncodeConstants (#9646)
befa7f3 is described below
commit befa7f33587068b6bc2ee7f57dc726d7bb1dc365
Author: Matthew Barrett <55...@users.noreply.github.com>
AuthorDate: Sat Dec 4 07:43:56 2021 +0000
[microNPU] Fix bug with re-reading in EncodeConstants (#9646)
When a striping strategy that leads to weights
being re-read was deployed, the logic in EncodeConstants
failed. This adds a test for that case and fixes the
pass so it handles it correctly.
Change-Id: I6f54cdb7be69428e49c3b4208271cd3e6c192e5d
---
.../tvm/relay/backend/contrib/ethosu/tir/passes.py | 10 +++-
.../contrib/test_ethosu/test_encode_constants.py | 66 ++++++++++++++++++++++
2 files changed, 73 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index 41a6832..0a6dcd1 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -331,11 +331,15 @@ def EncodeConstants(const_dict):
def _new_buffer(old_buffer, new_value):
"""Create a new buffer and add the old buffer and its pointer to the
rewriting maps."""
- new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype))
- pointer_to_buffer[new_buffer.data] = new_buffer
+ if old_buffer in rewrite_buffer:
+ new_buffer = rewrite_buffer[old_buffer]
+ else:
+ new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype))
+ pointer_to_buffer[new_buffer.data] = new_buffer
+ buffer_to_const[new_buffer] = new_value
+
rewrite_buffer[old_buffer] = new_buffer
rewrite_pointer[old_buffer.data] = new_buffer.data
- buffer_to_const[new_buffer] = new_value
def _visit_encode_pre(stmt):
if isinstance(stmt, tvm.tir.Call):
diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py
index de8a7f9..7f5eeb1 100644
--- a/tests/python/contrib/test_ethosu/test_encode_constants.py
+++ b/tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -110,6 +110,72 @@ def test_weight_stream_only():
# fmt: off
@tvm.script.ir_module
+class RereadWeights:
+ @T.prim_func
+ def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None:
+ # function attr dict
+ T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+ placeholder_3 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8")
+ buffer = T.match_buffer(placeholder_1, [304], dtype="uint8")
+ buffer_1 = T.match_buffer(placeholder_2, [80], dtype="uint8")
+ ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8")
+ # body
+ placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True})
+ placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True})
+ T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder_3.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write_1.data, 64), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+ __tvm_meta__ = None
+# fmt: on
+
+
+def test_re_read_weights():
+ def _cascader(cached_func, const_dict, sch):
+ weights = cached_func.inputs[1]
+ bias = cached_func.inputs[2]
+ out = cached_func.outputs[0]
+ conv_compute = Convolution2DCompute.from_output(out)
+ co = conv_compute.split(sch, 2, 8)
+ cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d])
+ cache_bias = sch.cache_read(bias, "global", [conv_compute.conv2d])
+ sch[cache_weights].compute_at(sch[out], co)
+ sch[cache_bias].compute_at(sch[out], co)
+
+ def _get_func():
+ ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8")
+ conv = make_ethosu_conv2d(
+ ifm,
+ 32,
+ 8,
+ (1, 1),
+ (0, 0),
+ (1, 1),
+ (1, 1),
+ )
+ func = relay.Function(relay.analysis.free_vars(conv), conv)
+ func = run_opt_pass(func, relay.transform.InferType())
+ return func
+
+ func = _get_func()
+ mod, consts = lower_to_tir(func, cascader=_cascader)
+ script = mod.script(show_meta=True)
+ test_mod = tvm.script.from_source(script)
+ reference_mod = RereadWeights
+ tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
+
+ reference_const_sizes = {1: 304, 2: 80}
+ test_const_sizes = {}
+ for key, value in consts.items():
+ test_const_sizes[key] = len(value)
+
+ assert reference_const_sizes == test_const_sizes
+
+
+# fmt: off
+@tvm.script.ir_module
class DirectReadOnly:
@T.prim_func
def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_write: T.handle) -> None: