You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by cs...@apache.org on 2022/12/01 18:56:37 UTC
[tvm] branch main updated: [TIR][Analysis][Hexagon] Add vtcm memory capacity verification for Hexagon target (#13349)
This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new afbfb7aa7e [TIR][Analysis][Hexagon] Add vtcm memory capacity verification for Hexagon target (#13349)
afbfb7aa7e is described below
commit afbfb7aa7e43732cb716f8e443df696110be6afc
Author: Alexey Voronov <av...@gmail.com>
AuthorDate: Thu Dec 1 21:56:31 2022 +0300
[TIR][Analysis][Hexagon] Add vtcm memory capacity verification for Hexagon target (#13349)
The main items that have been added are:
* tvm.tir.analysis.calculate_allocated_bytes(), to calculate allocated memory per memory scope
* tir.transform.VerifyVTCMLimit(limit), to verify if the size of the allocated vtcm memory satisfies the limit
* tvm.target.hexagon().vtcm_capacity, attribute to pass the limit
* tir.vtcm_capacity, context configuration attribute to pass the limit alternatively
---
include/tvm/tir/analysis.h | 16 +++
python/tvm/autotvm/measure/measure_methods.py | 33 ++++--
python/tvm/target/target.py | 8 ++
python/tvm/tir/analysis/analysis.py | 16 +++
python/tvm/tir/transform/transform.py | 11 ++
src/auto_scheduler/feature.cc | 7 ++
src/auto_scheduler/search_policy/utils.h | 5 +
src/driver/driver_api.cc | 17 ++-
src/target/target_kind.cc | 1 +
src/tir/analysis/calculate_allocated_memory.cc | 117 +++++++++++++++++++++
.../python/contrib/test_hexagon/infrastructure.py | 4 +-
tests/python/contrib/test_hexagon/test_vtcm.py | 55 +++++++---
...test_tir_analysis_calculate_allocated_memory.py | 101 ++++++++++++++++++
13 files changed, 366 insertions(+), 25 deletions(-)
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index e9796eca65..cb31a7e5ee 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -217,6 +217,12 @@ TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, const Integer& const
TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
const Integer& workspace_byte_alignment);
+/*!
+ * \brief Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc
+ * \param func The TIR PrimFunc for which the allocated memory size is to be calculated
+ */
+TVM_DLL tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func);
+
/*!
* \brief Detect the lowest common ancestor(LCA) of buffer access, including both high-level
* access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access).
@@ -294,6 +300,16 @@ TVM_DLL Pass VerifyMemory();
*/
TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
+/*!
+ * \brief Pass to check if the size of the allocated vtcm memory satisfies the limit
+ *
+ * \param limit The limit to check.
+ *
+ * \returns The pass.
+ * \sa tvm::tir::CalculateAllocatedBytes
+ */
+TVM_DLL Pass VerifyVTCMLimit(const Integer& limit);
+
/*!
* \brief Statically check TIR code for out of bounds array access.
*
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index 8fc0da89c4..f1c14c3cd9 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -330,7 +330,7 @@ class RPCRunner(Runner):
)
def get_build_kwargs(self):
- kwargs = {}
+ kwargs = {"checks": {}}
if (
"cuda" in self.task.target.keys
or "opencl" in self.task.target.keys
@@ -340,13 +340,15 @@ class RPCRunner(Runner):
remote = request_remote(self.key, self.host, self.port)
dev = remote.device(str(self.task.target), 0)
max_dims = dev.max_thread_dimensions
- kwargs["check_gpu"] = {
+ kwargs["checks"]["gpu"] = {
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
"max_threads_per_block": dev.max_threads_per_block,
"max_thread_x": max_dims[0],
"max_thread_y": max_dims[1],
"max_thread_z": max_dims[2],
}
+ if "hexagon" in self.task.target.keys:
+ kwargs["checks"]["hexagon"] = {"vtcm_capacity": self.task.target.vtcm_capacity}
return kwargs
@@ -493,11 +495,11 @@ class LocalRunner(RPCRunner):
return server, tracker
-def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option=None):
+def _build_func_common(measure_input, runtime=None, checks=None, build_option=None):
"""Common part for building a configuration"""
target, task, config = measure_input
target, task.target_host = Target.canon_target_and_host(target, task.target_host)
-
+ checks = checks or {}
with target:
s, args = task.instantiate(config)
@@ -526,8 +528,10 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option
current_add_lower_pass = list(current_config["tir.add_lower_pass"])
else:
current_add_lower_pass = []
- if check_gpu:
- current_add_lower_pass.append((2, gpu_verify_pass(**check_gpu)))
+ if checks.get("gpu"):
+ current_add_lower_pass.append((2, gpu_verify_pass(**checks.get("gpu"))))
+ if checks.get("hexagon"):
+ current_add_lower_pass.append((2, vtcm_verify_pass(**checks.get("hexagon"))))
current_config["tir.add_lower_pass"] = current_add_lower_pass
with tvm.ir.transform.PassContext(
@@ -872,3 +876,20 @@ def gpu_verify_pass(**kwargs):
return f
return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)
+
+
+def vtcm_verify_pass(**kwargs):
+ """Verify the validity of a hexagon kernel.
+ This pass will check vtcm memory usage.
+ """
+
+ def verify_pass(f, *_):
+ sizes = tvm.tir.analysis.calculate_allocated_bytes(f)
+ vtcm_capacity = kwargs.get("vtcm_capacity", 0)
+ vtcm_allocated = sizes.get("global.vtcm", 0)
+ if 0 < vtcm_capacity < vtcm_allocated:
+ raise InstantiationError("Skipped because of invalid vtcm memory usage limit")
+
+ return f
+
+ return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index 7081f992af..06e1776965 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -182,6 +182,10 @@ class Target(Object):
def max_function_args(self):
return int(self.attrs.get("max_function_args", -1))
+ @property
+ def vtcm_capacity(self):
+ return int(self.attrs.get("vtcm-capacity", 0))
+
@property
def device_name(self):
return str(self.attrs.get("device", ""))
@@ -642,6 +646,8 @@ def hexagon(cpu_ver="v66", **kwargs):
Whether to use IEEE HVX instructions
num_cores : int (default: 4)
The number of HVX threads. This attribute is required by meta scheduler.
+ vtcm_capacity: int (default: 0)
+ Hexagon VTCM capacity limitation. If the value is 0, the capacity is treated as unbounded.
Note: Floating point support in HVX requires LLVM 14+.
"""
@@ -675,6 +681,7 @@ def hexagon(cpu_ver="v66", **kwargs):
"llvm_options": None,
"use_qfloat": arch_version >= 68,
"use_ieee_fp": False,
+ "vtcm_capacity": 0,
}
config.update(kwargs)
@@ -748,6 +755,7 @@ def hexagon(cpu_ver="v66", **kwargs):
num_cores = config["num_cores"] if "num_cores" in kwargs else 4
args_list.append("--num-cores=%d" % num_cores)
+ args_list.append("--vtcm-capacity=%d" % config["vtcm_capacity"])
return Target(" ".join(["hexagon"] + args_list))
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index efb869efd6..45b1f745c3 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -201,6 +201,22 @@ def calculate_constant_bytes(func: PrimFunc, constant_byte_alignment: int) -> in
return _ffi_api.calculate_constant_bytes(func, constant_byte_alignment) # type: ignore
+def calculate_allocated_bytes(func: PrimFunc) -> Dict[str, int]:
+ """Calculate allocated memory per memory scope required by TIR PrimFuncs.
+
+ Parameters
+ ----------
+ func: tvm.tir.PrimFunc
+ The function to be analyzed.
+
+ Returns
+ -------
+ result : Dict[String, int]
+ Allocated memory size per scope in bytes.
+ """
+ return _ffi_api.calculate_allocated_bytes(func) # type: ignore
+
+
def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
"""Detect the lowest common ancestor(LCA) of buffer access, including both high-level
access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access).
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 82533a2f9f..81b90d5f40 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -611,6 +611,17 @@ def VerifyMemory():
return _ffi_api.VerifyMemory() # type: ignore
+def VerifyVTCMLimit(limit: int):
+ """Verify if the size of the allocated vtcm memory satisfies the limit.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.VerifyVTCMLimit(limit) # type: ignore
+
+
# pylint: disable=no-else-return,inconsistent-return-statements
def HoistIfThenElse(variant: Optional[str] = None):
"""Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 2f993c0c8b..4ce7ad13bc 100644
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -1401,6 +1401,13 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
const auto& optimize = tir::transform::Sequential(pass_list);
optimize(mod);
}
+ if (IsHexagonTask(task)) {
+ Target target = task->target;
+ const auto vtcm_capacity = target->GetAttr<Integer>("vtcm-capacity").value().IntValue();
+ const auto& optimize =
+ tir::transform::Sequential({tir::transform::VerifyVTCMLimit(vtcm_capacity)});
+ optimize(mod);
+ }
const auto& optimize =
tir::transform::Sequential(Array<tvm::transform::Pass>{tir::transform::Simplify()});
mod = optimize(std::move(mod));
diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h
index 44b60de1d7..ca8979c0e8 100644
--- a/src/auto_scheduler/search_policy/utils.h
+++ b/src/auto_scheduler/search_policy/utils.h
@@ -58,6 +58,11 @@ inline bool IsGPUTask(const SearchTask& task) {
device_type == kDLMetal || device_type == kDLROCM || device_type == kOpenGL;
}
+/*! \brief Return whether the search task is targeting a Hexagon. */
+inline bool IsHexagonTask(const SearchTask& task) {
+ return (task)->target->GetTargetDeviceType() == kDLHexagon;
+}
+
/*! \brief Return whether the search task is targeting a CUDA GPU. */
inline bool IsCUDATask(const SearchTask& task) {
return (task)->target->GetTargetDeviceType() == kDLCUDA;
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 90676e0b84..10d9e8023a 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -54,6 +54,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.dma_bypass_cache", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
using tvm::Array;
using tvm::transform::Pass;
@@ -225,8 +226,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
if (!disable_storage_rewrite) {
pass_list.push_back(tir::transform::StorageRewrite());
}
- // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
- pass_list.push_back(tir::transform::LowerVtcmAlloc());
bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();
if (use_async_copy) {
@@ -532,11 +531,25 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg,
return TIRToRuntime(inputs, target_host);
}
+int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) {
+ if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true);
+ if (target.defined() && target->kind->name == "hexagon") {
+ auto value = Downcast<Integer>(target->attrs.at("vtcm-capacity"))->value;
+ if (value > 0) return value;
+ }
+ return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", Integer(0)).value()->value;
+}
+
transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
transform::PassContext pass_ctx = transform::PassContext::Current();
Array<Pass> mixed_pass_list;
+ // VerifyVTCMLimit must occur before LowerVtcmAlloc
+ mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(GetVTCMCapacity(target, pass_ctx)));
+ // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
+ mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());
+
mixed_pass_list.push_back(tir::transform::BindTarget(target));
mixed_pass_list.push_back(tir::transform::VerifyMemory());
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index ef350004ad..a87bb92c48 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -421,6 +421,7 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
.add_attr_option<String>("mtriple")
.add_attr_option<Array<String>>("llvm-options")
.add_attr_option<Integer>("num-cores")
+ .add_attr_option<Integer>("vtcm-capacity")
.set_default_keys({"hexagon"});
TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU);
diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc
new file mode 100644
index 0000000000..01457508ab
--- /dev/null
+++ b/src/tir/analysis/calculate_allocated_memory.cc
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/calculate_allocated_memory.cc
+ * \brief Calculate allocated memory per memory scope required by PrimFuncs.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/container/map.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <algorithm>
+#include <map>
+#include <unordered_map>
+
+namespace tvm {
+namespace tir {
+
+template <typename T>
+class AllocationCalculator : public StmtExprVisitor {
+ public:
+ AllocationCalculator() = default;
+ tvm::Map<String, Integer> operator()(const PrimFunc& func);
+
+ private:
+ void VisitStmt_(const T* op) override;
+ std::unordered_map<std::string, int64_t> _max_size;
+ std::unordered_map<std::string, int64_t> _current_size;
+};
+
+template <typename T>
+tvm::Map<String, Integer> AllocationCalculator<T>::operator()(const PrimFunc& func) {
+ this->VisitStmt(func->body);
+ tvm::Map<String, Integer> res;
+ for (auto [k, v] : _max_size) {
+ res.Set(String(k), Integer(v));
+ }
+ return res;
+}
+
+std::string GetStorageScope(const Var& var) {
+ auto* ptr = var->type_annotation.as<PointerTypeNode>();
+ ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType";
+ return ptr->storage_scope;
+}
+
+template <typename T>
+void AllocationCalculator<T>::VisitStmt_(const T* op) {
+ std::string storage_scope = GetStorageScope(op->buffer_var);
+ auto search = _current_size.find(storage_scope);
+ if (search == _current_size.end()) {
+ _current_size[storage_scope] = 0;
+ _max_size[storage_scope] = 0;
+ }
+ auto size = op->ConstantAllocationSize() * op->dtype.bytes() * op->dtype.lanes();
+ _current_size[storage_scope] += size;
+ _max_size[storage_scope] = std::max(_current_size[storage_scope], _max_size[storage_scope]);
+ StmtExprVisitor::VisitStmt(op->body);
+ _current_size[storage_scope] -= size;
+}
+
+tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func) {
+ return AllocationCalculator<AllocateNode>()(func);
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes").set_body_typed([](PrimFunc func) {
+ return CalculateAllocatedBytes(func);
+});
+
+namespace transform {
+
+Pass VerifyVTCMLimit(const Integer& limit) {
+ auto pass_func = [=](IRModule mod, PassContext ctx) {
+ for (auto kv : mod->functions) {
+ if (auto* n = kv.second.as<PrimFuncNode>()) {
+ auto func = GetRef<PrimFunc>(n);
+ auto sizes = CalculateAllocatedBytes(func);
+ const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
+ if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
+ LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been "
+ "exceeded(allocated: "
+ << vtcm_allocated << ", limit: " << limit << ").\n"
+ << "In function\n"
+ << func;
+ }
+ }
+ }
+ return mod;
+ };
+ return tvm::transform::CreateModulePass(pass_func, 0, "tir.calculate_allocated_bytes", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.VerifyVTCMLimit").set_body_typed(VerifyVTCMLimit);
+
+} // namespace transform
+} // namespace tir
+} // namespace tvm
diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py
index 5b13513c0f..fcb811fce7 100644
--- a/tests/python/contrib/test_hexagon/infrastructure.py
+++ b/tests/python/contrib/test_hexagon/infrastructure.py
@@ -324,7 +324,7 @@ def quantize_np(arr_np: numpy.ndarray, dtype: str):
return quant_np, scale, zero_point
-def get_hexagon_target(cpu_ver: str) -> tvm.target.Target:
+def get_hexagon_target(cpu_ver: str, **kwargs) -> tvm.target.Target:
"""Creates a Hexagon target"""
- target = tvm.target.hexagon(cpu_ver)
+ target = tvm.target.hexagon(cpu_ver, **kwargs)
return tvm.target.Target(target, host=target)
diff --git a/tests/python/contrib/test_hexagon/test_vtcm.py b/tests/python/contrib/test_hexagon/test_vtcm.py
index 11188436a3..e71f890740 100644
--- a/tests/python/contrib/test_hexagon/test_vtcm.py
+++ b/tests/python/contrib/test_hexagon/test_vtcm.py
@@ -16,9 +16,11 @@
# under the License.
"""VTCM Tests"""
+import pytest
import tvm.testing
from tvm import tir
from tvm.script import tir as T
+from .infrastructure import get_hexagon_target
@T.prim_func
@@ -31,8 +33,7 @@ def scale_by_two(buffer_a: T.Buffer[(8192,), "int8"], buffer_c: T.Buffer[(8192,)
buffer_c[i] = buffer_a[i] * T.int8(2)
-def test_vtcm_lowering():
- """Test lowering with vtcm mem scope"""
+def get_scale_by_two_schedule():
mod = tvm.IRModule.from_expr(scale_by_two.with_attr("global_symbol", "main"))
sch = tir.Schedule(mod, debug_mask="all")
block_c = sch.get_block("C")
@@ -40,23 +41,47 @@ def test_vtcm_lowering():
outer, _, _, _ = sch.split(flat, factors=[8, 4, 2, 128])
cache_block = sch.cache_read(block_c, 0, storage_scope="global.vtcm")
sch.compute_at(cache_block, outer)
- lowered = tvm.lower(sch.mod["main"])
+ return sch
- def ir_module_has_allocate_nodes(irmod):
- nallocs = 0
- def _visit(stmt):
- nonlocal nallocs
- if isinstance(stmt, tvm.tir.Allocate):
- nallocs += 1
+@tvm.testing.requires_hexagon
+def test_vtcm_building():
+ """Test building with vtcm mem scope"""
+ sch = get_scale_by_two_schedule()
+ target = get_hexagon_target("v68")
+ built = tvm.build(sch.mod, target=target)
+ assert "global.vtcm" in built.get_source("asm")
- tvm.tir.stmt_functor.post_order_visit(irmod["main"].body, _visit)
- return nallocs
- assert not ir_module_has_allocate_nodes(lowered), (
- "AllocateNode found in lowered IRModule, "
- "VTCM allocations should have been lowered to tir.nd_mem_alloc_with_scope"
- )
+@tvm.testing.requires_hexagon
+@pytest.mark.parametrize("vtcm_capacity,limited", [(8192, False), (1024, False), (128, True)])
+def test_vtcm_limit(vtcm_capacity, limited):
+ """Test building with vtcm mem scope limit"""
+ sch = get_scale_by_two_schedule()
+
+ def _raises_exception(f):
+ try:
+ f()
+ except tvm._ffi.base.TVMError:
+ return True
+ return False
+
+ target = get_hexagon_target("v68", vtcm_capacity=vtcm_capacity)
+
+ assert (
+ _raises_exception(lambda: tvm.build(sch.mod, target=target)) == limited
+ ), "Case 1 - arg. VTCM memory allocation limiter does not work correctly "
+
+ with target:
+ assert (
+ _raises_exception(lambda: tvm.build(sch.mod)) == limited
+ ), "Case 2 - with.VTCM memory allocation limiter does not work correctly "
+
+ with tvm.transform.PassContext(config={"tir.vtcm_capacity": vtcm_capacity}):
+ assert (
+ _raises_exception(lambda: tvm.build(sch.mod, target=get_hexagon_target("v68")))
+ == limited
+ ), "Case 3 - context. VTCM memory allocation limiter does not work correctly "
if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
new file mode 100644
index 0000000000..1a2d50ef5d
--- /dev/null
+++ b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+
+
+@T.prim_func
+def scale_by_two(a: T.Buffer[(128,), "int8"], c: T.Buffer[(128,), "int8"]):
+ for i in T.serial(128):
+ with T.block("C"):
+ c[i] = a[i] * T.int8(2)
+
+
+@T.prim_func
+def scale_by_two_three(a: T.Buffer[(128,), "int8"], c: T.Buffer[(128,), "int8"]):
+ B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm")
+ for i in T.serial(128):
+ with T.block("B"):
+ B[i] = a[i] * T.int8(2)
+ for i in T.serial(128):
+ with T.block("C"):
+ c[i] = B[i] * T.int8(3)
+
+
+@pytest.mark.parametrize("primFunc,size", [(scale_by_two, 128), (scale_by_two_three, 256)])
+def test_scale_by(primFunc, size):
+ """Test calculate allocated bytes per scope"""
+ mod = tvm.IRModule.from_expr(primFunc.with_attr("global_symbol", "main"))
+ sch = tir.Schedule(mod, debug_mask="all")
+ block_c = sch.get_block("C")
+ (flat,) = sch.get_loops(block_c)
+ cache_block = sch.cache_read(block_c, 0, storage_scope="global.vtcm")
+ sch.compute_at(cache_block, flat)
+
+ mod = sch.mod
+ mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
+ mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
+ sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
+ assert sizes.get("global.vtcm", 0) == size
+
+
+@T.prim_func
+def matmul_mix_scope(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, [128, 128], scope="global")
+ B = T.match_buffer(b, [128, 128], scope="global")
+ C = T.match_buffer(c, [128, 128], scope="global")
+ A_allocated = T.alloc_buffer([128, 128], dtype="float32", scope="global.texture")
+ B_allocated = T.alloc_buffer([128, 128], dtype="float32", scope="global.texture")
+ C_allocated = T.alloc_buffer([128, 128], dtype="float32", scope="global")
+
+ for i, j in T.grid(128, 128):
+ with T.block("A.allocated"):
+ A_allocated[i, j] = A[i, j]
+ for i, j in T.grid(128, 128):
+ with T.block("B.allocated"):
+ B_allocated[i, j] = B[i, j]
+
+ for i, j, k in T.grid(128, 128, 128):
+ with T.block("update"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ C_allocated[vi, vj] = 0.0
+ C_allocated[vi, vj] = C[vi, vj] + A_allocated[vi, vk] * B_allocated[vj, vk]
+
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ C[i, j] = C_allocated[i, j]
+
+
+@pytest.mark.parametrize(
+ "scope,size", [("global", 65536), ("global.texture", 131072), ("global.texture-nhwc", 0)]
+)
+def test_matmul_mix_scope(scope, size):
+ """Test calculate allocated bytes per scope"""
+ mod = tvm.IRModule({"main": matmul_mix_scope})
+ mod = tvm.tir.transform.LowerInitBlock()(mod)
+ mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
+ mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
+ sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"])
+ assert sizes.get(scope, 0) == size
+
+
+if __name__ == "__main__":
+ tvm.testing.main()