You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/03/29 19:12:01 UTC

[tvm] branch unity updated: [Unity] Include constant shapes in the profiler result (#14428)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 48d972a8b5 [Unity] Include constant shapes in the profiler result (#14428)
48d972a8b5 is described below

commit 48d972a8b560d3ca198d8b746b82591cd083c1fc
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Mar 29 12:11:52 2023 -0700

    [Unity] Include constant shapes in the profiler result (#14428)
---
 src/runtime/relax_vm/vm.cc | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index 33a19c0482..2e6c341213 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -361,7 +361,6 @@ class VirtualMachineImpl : public VirtualMachine {
 
   void ClearInputsFor(const std::string& func_name) { inputs_.erase(func_name); }
 
- private:
   //--------------------------------------------------------
   // Internal states for execution.
   //--------------------------------------------------------
@@ -983,15 +982,23 @@ class VirtualMachineProfiler : public VirtualMachineImpl {
       auto f_name = GetFuncName(inst.func_idx);
       std::optional<Device> dev;
       std::vector<NDArray> arrs;
+
+      auto f_check_ndarray_arg = [&dev, &arrs](const RegType& arg) {
+        if (arg.type_code() == kTVMNDArrayHandle) {
+          NDArray arr = arg;
+          dev = arr->device;
+          arrs.push_back(arr);
+        }
+      };
+
       for (Index i = 0; i < inst.num_args; ++i) {
         Instruction::Arg arg = inst.args[i];
         if (arg.kind() == Instruction::ArgKind::kRegister) {
           auto reg = ReadRegister(curr_frame, arg.value());
-          if (reg.type_code() == kTVMNDArrayHandle) {
-            NDArray arr = reg;
-            dev = arr->device;
-            arrs.push_back(arr);
-          }
+          f_check_ndarray_arg(reg);
+        } else if (arg.kind() == Instruction::ArgKind::kConstIdx) {
+          const auto& const_val = this->const_pool_[arg.value()];
+          f_check_ndarray_arg(const_val);
         }
       }