You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by cs...@apache.org on 2022/05/13 20:22:02 UTC
[tvm] branch main updated: Avoid use of MemoryInfo when undefined in StorageRewrite (#11254)
This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 19ce068149 Avoid use of MemoryInfo when undefined in StorageRewrite (#11254)
19ce068149 is described below
commit 19ce0681498dbf409ebb12906acf5712bf8c7ea7
Author: Chris Sullivan <cs...@octoml.ai>
AuthorDate: Fri May 13 13:21:57 2022 -0700
Avoid use of MemoryInfo when undefined in StorageRewrite (#11254)
* Check if the requested memory info is defined before using it.
* Address review comment: add a warning when the MemoryInfo for a scope is undefined.
---
src/target/target_info.cc | 1 +
src/tir/transforms/storage_rewrite.cc | 16 ++++++++++------
2 files changed, 11 insertions(+), 6 deletions(-)
diff --git a/src/target/target_info.cc b/src/target/target_info.cc
index 5ebb7edc80..d83c8beac4 100644
--- a/src/target/target_info.cc
+++ b/src/target/target_info.cc
@@ -42,6 +42,7 @@ MemoryInfo GetMemoryInfo(const std::string& scope) {
std::string fname = "tvm.info.mem." + scope;
const runtime::PackedFunc* f = runtime::Registry::Get(fname);
if (f == nullptr) {
+ LOG(WARNING) << "MemoryInfo for scope = " << scope << " is undefined";
return MemoryInfo();
} else {
return (*f)();
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 27a4d74100..c5f27b8de3 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -661,9 +661,11 @@ class StoragePlanRewriter : public StmtExprMutator {
e->allocs[0]->condition, Evaluate(0));
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
- uint64_t total_elem = e->const_nbits / e->elem_type.bits();
- ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
- << "Allocation exceed bound of memory tag " << e->scope.to_string();
+ if (info.defined()) {
+ uint64_t total_elem = e->const_nbits / e->elem_type.bits();
+ ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
+ << "Allocation exceed bound of memory tag " << e->scope.to_string();
+ }
}
} else {
// Build a merged allocation
@@ -707,9 +709,11 @@ class StoragePlanRewriter : public StmtExprMutator {
Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0));
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
- uint64_t total_elem = e->const_nbits / e->elem_type.bits();
- ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
- << "Allocation exceed bound of memory tag " << e->scope.to_string();
+ if (info.defined()) {
+ uint64_t total_elem = e->const_nbits / e->elem_type.bits();
+ ICHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
+ << "Allocation exceed bound of memory tag " << e->scope.to_string();
+ }
}
}
}