You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/23 19:48:21 UTC

[GitHub] [tvm] tkonolige commented on a change in pull request #7500: [WIP][Pass] Profiling TVM compiler passes

tkonolige commented on a change in pull request #7500:
URL: https://github.com/apache/tvm/pull/7500#discussion_r581338476



##########
File path: src/ir/transform.cc
##########
@@ -169,6 +170,126 @@ void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_be
 
 class ModulePass;
 
+struct PassProfile {
+  typedef std::chrono::steady_clock Clock;
+  typedef std::chrono::microseconds Duration;
+  typedef std::chrono::time_point<Clock> Time;
+
+  String name;
+  Time start;
+  Time end;
+  Duration duration;
+  std::vector<PassProfile> children;
+
+  PassProfile(String name) : name(name), start(Clock::now()), end(Clock::now()), children() {}
+
+  static PassProfile* Current();
+  static void EnterPass(String name);
+  static void ExitPass();
+};
+
+struct PassProfileThreadLocalEntry {
+  // TODO: figure out the TVM way to do this
+  PassProfile root;
+  std::stack<PassProfile*> profile_stack;
+
+  PassProfileThreadLocalEntry() : root("root") {}
+};
+
+typedef dmlc::ThreadLocalStore<PassProfileThreadLocalEntry> PassProfileThreadLocalStore;
+
+void PassProfile::EnterPass(String name) {
+  PassProfile* cur = PassProfile::Current();
+  cur->children.emplace_back(name);
+  PassProfileThreadLocalStore::Get()->profile_stack.push(&cur->children.back());
+}
+
+void PassProfile::ExitPass() {
+  PassProfile* cur = PassProfile::Current();
+  ICHECK_NE(cur->name, "root");
+  cur->end = std::move(PassProfile::Clock::now());
+  cur->duration = std::chrono::duration_cast<PassProfile::Duration>(cur->end - cur->start);
+  PassProfileThreadLocalStore::Get()->profile_stack.pop();
+}
+
+PassProfile* PassProfile::Current() {
+  PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
+  if (!entry->profile_stack.empty()) {
+    return entry->profile_stack.top();
+  } else {
+    return &entry->root;
+  }
+}
+
+IRModule Pass::operator()(IRModule mod) const {
+  const PassNode* node = operator->();
+  ICHECK(node != nullptr);
+  PassProfile::EnterPass(node->Info()->name);
+  auto ret = node->operator()(std::move(mod));
+  PassProfile::ExitPass();
+  return std::move(ret);
+}
+
+IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const {
+  const PassNode* node = operator->();
+  ICHECK(node != nullptr);
+  PassProfile::EnterPass(node->Info()->name);
+  auto ret = node->operator()(std::move(mod), pass_ctx);
+  PassProfile::ExitPass();
+  return std::move(ret);
+}
+
+void PrintPassProfile() {
+  PassProfileThreadLocalEntry* entry = PassProfileThreadLocalStore::Get();
+  ICHECK(entry->profile_stack.empty()) << "cannot print pass profile while still in a pass!";

Review comment:
       Let's make this a user-facing `CHECK`.

##########
File path: src/ir/transform.cc
##########
@@ -169,6 +170,126 @@ void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_be
 
 class ModulePass;
 
+struct PassProfile {
+  typedef std::chrono::steady_clock Clock;
+  typedef std::chrono::microseconds Duration;
+  typedef std::chrono::time_point<Clock> Time;

Review comment:
       These should all be `using` aliases, e.g. `using Clock = std::chrono::steady_clock;`.
   
   Also, I think you could use `std::chrono::duration<double, std::micro>` if you want floating-point values.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org