You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/03/29 19:12:01 UTC
[tvm] branch unity updated: [Unity] Include constant shapes in the profiler result (#14428)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 48d972a8b5 [Unity] Include constant shapes in the profiler result (#14428)
48d972a8b5 is described below
commit 48d972a8b560d3ca198d8b746b82591cd083c1fc
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Mar 29 12:11:52 2023 -0700
[Unity] Include constant shapes in the profiler result (#14428)
---
src/runtime/relax_vm/vm.cc | 19 +++++++++++++------
1 file changed, 13 insertions(+), 6 deletions(-)
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index 33a19c0482..2e6c341213 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -361,7 +361,6 @@ class VirtualMachineImpl : public VirtualMachine {
void ClearInputsFor(const std::string& func_name) { inputs_.erase(func_name); }
- private:
//--------------------------------------------------------
// Internal states for execution.
//--------------------------------------------------------
@@ -983,15 +982,23 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
auto f_name = GetFuncName(inst.func_idx);
std::optional<Device> dev;
std::vector<NDArray> arrs;
+
+ auto f_check_ndarray_arg = [&dev, &arrs](const RegType& arg) {
+ if (arg.type_code() == kTVMNDArrayHandle) {
+ NDArray arr = arg;
+ dev = arr->device;
+ arrs.push_back(arr);
+ }
+ };
+
for (Index i = 0; i < inst.num_args; ++i) {
Instruction::Arg arg = inst.args[i];
if (arg.kind() == Instruction::ArgKind::kRegister) {
auto reg = ReadRegister(curr_frame, arg.value());
- if (reg.type_code() == kTVMNDArrayHandle) {
- NDArray arr = reg;
- dev = arr->device;
- arrs.push_back(arr);
- }
+ f_check_ndarray_arg(reg);
+ } else if (arg.kind() == Instruction::ArgKind::kConstIdx) {
+ const auto& const_val = this->const_pool_[arg.value()];
+ f_check_ndarray_arg(const_val);
}
}