You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/25 12:03:55 UTC

[tvm] branch main updated: [AOT] Calculate used memory at the callsite of primitive functions (#11208)

This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d6e070587 [AOT] Calculate used memory at the callsite of primitive functions (#11208)
6d6e070587 is described below

commit 6d6e0705873b0b64576127fd6038720ef6c9c338
Author: Luke Hutton <lu...@arm.com>
AuthorDate: Sat Jun 25 13:03:47 2022 +0100

    [AOT] Calculate used memory at the callsite of primitive functions (#11208)
    
    * [AOT] Calculate used memory at the callsite of primitive functions
    
    Introduces a new pass in the AOT executor called "AnnotateUsedMemory"
    which applies liveness analysis to the callsite of each primitive
    function in order to calculate the total size of the live tensors at
    this point of execution. The result is provided as a function annotation
    called "used_memory", which can be consumed by later stages of the
    compiler (e.g. external codegens) to provide more information about the
    current memory consumption. This can be useful for some optimizations.
    
    Change-Id: I8d6b7447498f19260358bbefe34029ddd86b9c89
    
    * small fix to file description
    
    Change-Id: I0e460f6cf43f9b12ffa5fc66fcb68e55304daeb2
    
    * Various improvements addressing comments
    
    In addition, a new "io_used_memory" annotation is added to the main
    function which refers to the total size of the IO tensors in the
    provided module, enabling these to be discounted from memory pressure
    calculations where necessary.
    
    Change-Id: Iafe9c85d7fc69c77a2115ed4efe7645160387c86
    
    * addressing comments
    
    Change-Id: I00f5ba80d5e004076e4c27d39bec143178b3b1dd
    
    * add note for dynamic shapes
    
    Change-Id: If6409e2953addfc880bcc6d95083b78bdf5a23d0
---
 include/tvm/relay/transform.h                    |  13 +
 src/relay/backend/annotate_used_memory.cc        | 233 ++++++++++++
 src/relay/backend/aot_executor_codegen.cc        |   2 +
 src/relay/backend/liveness_analysis.cc           | 232 ++++++++++++
 src/relay/backend/liveness_analysis.h            | 270 ++++++++++++++
 src/relay/backend/vm/manifest_lifetimes.cc       | 388 +-------------------
 tests/python/relay/test_used_memory_annotator.py | 434 +++++++++++++++++++++++
 7 files changed, 1185 insertions(+), 387 deletions(-)

diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index b592265c74..1fef02557e 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -556,6 +556,19 @@ TVM_DLL Pass PlanDevices(CompilationConfig config);
  */
 TVM_DLL Pass FlattenAtrousConv();
 
+/*!
+ * \brief Annotates the minimum required memory of each primitive function callsite by analyzing
+ * the liveness of the input/output tensors at each function callsite and calculating the total
+ * amount of memory these tensors require. This is added as a "used_memory" annotation to the
+ * function in question as a list of the number of bytes for each callsite. In addition, the
+ * containing function is annotated with an "io_used_memory" annotation which refers to the total
+ * memory required for the IO tensors.
+ *
+ * Note: This pass does not support dynamic shapes; it is the user's responsibility to check this
+ * pass isn't applied where dynamic shapes may be input.
+ */
+TVM_DLL Pass AnnotateUsedMemory();
+
 }  // namespace transform
 
 /*!
diff --git a/src/relay/backend/annotate_used_memory.cc b/src/relay/backend/annotate_used_memory.cc
new file mode 100644
index 0000000000..ad370c73ad
--- /dev/null
+++ b/src/relay/backend/annotate_used_memory.cc
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "../transforms/pass_utils.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the minimum required memory of each primitive function callsite by analyzing
+ * the liveness of the input/output tensors at each function callsite and calculating the total
+ * amount of memory these tensors require. This is added as a "used_memory" annotation to the
+ * function in question as a list of the number of bytes for each callsite. In addition, the
+ * containing function is annotated with an "io_used_memory" annotation which refers to the total
+ * memory required for the IO tensors.
+ *
+ * Note: This pass does not support dynamic shapes; it is the user's responsibility to check this
+ * pass isn't applied where dynamic shapes may be input.
+ *
+ * A simple example:
+ *
+ * Before:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1 = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * After:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=[32]) -> Tensor[(1, 2,
+ * 2, 4), int8] {
+ *      nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * Note that in the simple example above io_used_memory and used_memory are the same since there
+ * is only one primitive function.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
+   * added to the input function which refers to the total size required for the IO
+   * tensors.
+   */
+  Function operator()(const Function& func) {
+    uint64_t io_used_memory = 0;
+
+    // Inputs
+    for (const Var& param : func->params) {
+      Type type = param->checked_type();
+      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+      ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes.";
+      io_used_memory += CalculateRelayExprSizeBytes(type);
+    }
+
+    // Outputs
+    Type type = func->body->checked_type();
+    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+    ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes.";
+    io_used_memory += CalculateRelayExprSizeBytes(type);
+
+    Expr new_func_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, func->params, new_func_body);
+    return WithAttr(std::move(new_func), "io_used_memory",
+                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
+  }
+
+  /*!
+   * \brief Establish which let bindings have primitive function values.
+   */
+  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) {
+    if (const auto* func_node = value.as<FunctionNode>()) {
+      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
+          << "Expect top-level functions to be primitive.";
+      let_bound_prim_func_.insert(var);
+    }
+    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
+  }
+
+  /*!
+   * \brief Visit let nodes and perform one of two actions depending on their value:
+   *
+   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
+   *               primitive functions.
+   *
+   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
+   *                   previous analysis at the callsite.
+   *
+   */
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    Var let_var = post_let_node->var;
+    Expr let_value = IgnoreOnDevice(post_let_node->value);
+
+    if (let_value->IsInstance<CallNode>()) {
+      Call callsite = Downcast<Call>(let_value);
+      if (CheckPrimitiveFunctionCall(callsite)) {
+        Var call_op = Downcast<Var>(callsite->op);
+
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // in and out varsets and then removing the var that references the primitive
+        // function itself, since we don't want this included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);
+        const transform::VarSet& live_out = liveness_.live_out.at(cfg_node);
+        live_tensors.insert(live_out.begin(), live_out.end());
+        live_tensors.erase(call_op);
+
+        // Calculate size of live tensors and store to allow annotation when the function
+        // gets visited.
+        uint64_t used_memory = 0;
+        for (const auto& var : live_tensors) {
+          Type type = var->checked_type();
+          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+          ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes.";
+          used_memory += CalculateRelayExprSizeBytes(type);
+        }
+        IntImm annotation(DataType::UInt(64), used_memory);
+        used_memory_annotations_[call_op].push_back(annotation);
+      }
+    } else if (let_value->IsInstance<FunctionNode>()) {
+      Function func = Downcast<Function>(let_value);
+      ICHECK(used_memory_annotations_.find(let_var) != used_memory_annotations_.end())
+          << "Could not find used_memory value for primitive function bound at "
+          << let_var->name_hint();
+      Array<IntImm> used_memory = used_memory_annotations_[let_var];
+      used_memory_annotations_.erase(let_var);
+
+      Function new_func = WithAttr(std::move(func), "used_memory",
+                                   Array<IntImm>(used_memory.rbegin(), used_memory.rend()));
+      return Let(let_var, new_func, post_let_node->body, post_let_node->span);
+    }
+
+    return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node);
+  }
+
+ private:
+  /*!
+   * \brief Check if a call is a primitive function callsite.
+   */
+  bool CheckPrimitiveFunctionCall(const Call& callsite) {
+    if (const auto* var_node = callsite->op.as<VarNode>()) {
+      Var var = GetRef<Var>(var_node);
+      if (let_bound_prim_func_.find(var) != let_bound_prim_func_.end()) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /*! \brief Control flow graph representation of the main function. */
+  transform::ControlFlowGraph control_flow_graph_;
+  /*! \brief Liveness analysis of the main function. */
+  transform::LivenessAnalysis liveness_;
+  /*! \brief Var's that reference primitive functions. */
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> let_bound_prim_func_;
+  /*! \brief Stores the calculated uint64 used_memory values so they can be annotated on the
+   * relevant function. */
+  std::unordered_map<Var, Array<IntImm>, ObjectPtrHash, ObjectPtrEqual> used_memory_annotations_;
+};
+
+}  // namespace backend
+
+namespace transform {
+
+Pass AnnotateUsedMemory() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
+                                                                            PassContext ctx) {
+    GlobalVar gv = mod->GetGlobalVar("main");
+    Function main_func = Downcast<Function>(mod->Lookup("main"));
+
+    // Perform liveness analysis to determine what tensors are 'live' at each function's callsite.
+    support::Arena arena;
+    ControlFlowGraph cfg = ControlFlowGraph::Create(&arena, main_func);
+    UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg);
+    LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def);
+
+    auto new_main_func = backend::AnnotateUsedMemoryMutator(mod, cfg, lva)(main_func);
+    if (!new_main_func.same_as(main_func)) {
+      mod->Update(gv, new_main_func);
+    }
+    return mod;
+  };
+  return CreateModulePass(pass_func, 0, "AnnotateUsedMemory", {"ToANormalForm", "InferType"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.AnnotateUsedMemory").set_body_typed(AnnotateUsedMemory);
+
+}  // namespace transform
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 5938417128..5020e79714 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -1079,6 +1079,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     }
 
     mod = transform::ToANormalForm()(mod);
+    mod = transform::InferType()(mod);
+    mod = transform::AnnotateUsedMemory()(mod);
 
     IRModule lowered_mod =
         tec::LowerTE(mod_name, config_, [this, workspace_byte_alignment](BaseFunc func) {
diff --git a/src/relay/backend/liveness_analysis.cc b/src/relay/backend/liveness_analysis.cc
new file mode 100644
index 0000000000..52db9e6a4c
--- /dev/null
+++ b/src/relay/backend/liveness_analysis.cc
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/liveness_analysis.cc
+ * \brief  Analysis that collects the live variables before and after each node.
+ * NOTE: the input IR should be in ANF.
+ */
+
+#include "./liveness_analysis.h"
+
+#include <list>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+using support::Arena;
+using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+
+ControlFlowGraph ControlFlowGraph::Create(Arena* arena, const Expr& body) {
+  return Creator().Create(arena, body);
+}
+
+ControlFlowGraph ControlFlowGraph::Creator::Create(Arena* arena, const Expr& body) {
+  arena_ = arena;
+  cfg_.entry = BasicBlock::Make(arena);
+  VisitExpr(body, cfg_.entry);
+  return std::move(cfg_);
+}
+
+void ControlFlowGraph::Creator::Succ(BasicBlockPtr from, BasicBlockPtr to) {
+  from->succ.push_back(to);
+  to->pred.push_back(from);
+}
+
+void ControlFlowGraph::Creator::VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) {
+  ICHECK(!in_func_) << "nested functions not supported by CFG analysis";
+  in_func_ = true;
+
+  // Unwrap the nested function and proceed normally.
+  if (f->HasNonzeroAttr(attr::kClosure)) {
+    ICHECK(f->body.as<FunctionNode>());
+    return VisitExpr(Downcast<Function>(f->body)->body, parent);
+  }
+
+  return VisitExpr(f->body, parent);
+}
+
+void ControlFlowGraph::Creator::VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) {
+  Expr expr = GetRef<Expr>(let_node);
+
+  while (const LetNode* inner_let_node = expr.as<LetNode>()) {
+    NodePtr curr_node = Node::Make(arena_, parent, expr);
+
+    ICHECK(!cfg_.let_map.count(expr));
+    cfg_.let_map[expr] = curr_node;
+    cfg_.reverse_post_order.push_back(curr_node);
+
+    // The basic block ends upon reaching control flow, with successor blocks corresponding to the
+    // control flow branch exprs (true/false in If, and one for each clause in Match).
+    if (const IfNode* ite = AsIgnoringOnDevice<IfNode>(inner_let_node->value)) {
+      // Create the basic blocks for each branch and mark them as successors to the current block.
+      BasicBlockPtr t_block = BasicBlock::Make(arena_);
+      BasicBlockPtr f_block = BasicBlock::Make(arena_);
+      Succ(parent, t_block);
+      Succ(parent, f_block);
+
+      VisitExpr(ite->true_branch, t_block);
+      VisitExpr(ite->false_branch, f_block);
+
+      // All subsequent bindings (and/or the body expr) will be in a new basic block.
+      BasicBlockPtr next = BasicBlock::Make(arena_);
+      Succ(t_block, next);
+      Succ(f_block, next);
+      parent = next;
+    } else if (const MatchNode* match = AsIgnoringOnDevice<MatchNode>(inner_let_node->value)) {
+      // Same as above but one for each pattern.
+      std::vector<BasicBlockPtr> clause_blocks;
+      BasicBlockPtr next = BasicBlock::Make(arena_);
+      for (const Clause& clause : match->clauses) {
+        BasicBlockPtr clause_block = BasicBlock::Make(arena_);
+        Succ(parent, clause_block);
+        Succ(clause_block, next);
+        VisitExpr(clause->rhs, clause_block);
+      }
+      parent = next;
+    }
+
+    expr = inner_let_node->body;
+  }
+
+  VisitExpr(expr, parent);
+}
+
+void ControlFlowGraph::Creator::VisitExpr_(const IfNode* if_node, BasicBlockPtr parent) {
+  // TODO(@altanh): is there a way of making this work?
+  LOG(FATAL) << "If expressions should be bound to variables.";
+}
+
+void ControlFlowGraph::Creator::VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent) {
+  // TODO(@altanh): same as If
+  LOG(FATAL) << "Match expressions should be bound to variables.";
+}
+
+VarSet VarUseCollector::VisitExpr_(const VarNode* var_node) { return {GetRef<Var>(var_node)}; }
+
+VarSet VarUseCollector::VisitExpr_(const CallNode* call_node) {
+  VarSet use = VisitExpr(call_node->op);
+  for (const Expr& arg : call_node->args) {
+    VarSet arg_use = VisitExpr(arg);
+    use.insert(arg_use.begin(), arg_use.end());
+  }
+  return use;
+}
+
+VarSet VarUseCollector::VisitExpr_(const TupleNode* tuple_node) {
+  VarSet use;
+  for (const Expr& field : tuple_node->fields) {
+    VarSet field_use = VisitExpr(field);
+    use.insert(field_use.begin(), field_use.end());
+  }
+  return use;
+}
+
+VarSet VarUseCollector::VisitExpr_(const TupleGetItemNode* get_node) {
+  return VisitExpr(get_node->tuple);
+}
+
+VarSet VarUseCollector::VisitExpr_(const IfNode* if_node) { return VisitExpr(if_node->cond); }
+
+VarSet VarUseCollector::VisitExpr_(const MatchNode* match_node) {
+  return VisitExpr(match_node->data);
+}
+
+UseDefAnalysis UseDefAnalysis::Analyze(const CFG& cfg) {
+  UseDefAnalysis a;
+
+  // One pass is sufficient.
+  for (auto it = cfg.reverse_post_order.begin(); it != cfg.reverse_post_order.end(); ++it) {
+    const CFG::NodePtr& node = *it;
+    if (const LetNode* let_node = AsIgnoringOnDevice<LetNode>(node->expr)) {
+      a.use[node] = a.use_collector.VisitExpr(let_node->value);
+      a.def[node] = let_node->var;
+    } else {
+      a.use[node] = a.use_collector.VisitExpr(node->expr);
+      a.def[node] = Var();
+    }
+  }
+
+  return a;
+}
+
+bool SetEqual(const VarSet& a, const VarSet& b) {
+  if (a.size() != b.size()) {
+    return false;
+  }
+  for (auto& xa : a) {
+    if (!b.count(xa)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+LivenessAnalysis LivenessAnalysis::Analyze(const ControlFlowGraph& cfg,
+                                           const UseDefAnalysis& use_def) {
+  LivenessAnalysis a;
+  std::list<CFG::NodePtr> worklist;
+
+  // Initialize worklist to post-order traversal for quick convergence.
+  worklist.insert(worklist.end(), cfg.reverse_post_order.rbegin(), cfg.reverse_post_order.rend());
+
+  // See https://lambda.uta.edu/cse5317/notes/node40.html for an overview of the algorithm.
+  auto visitor = [&](const CFG::NodePtr n) {
+    VarSet old_in_n = a.live_in[n];
+    VarSet old_out_n = a.live_out[n];
+
+    a.live_in[n] = use_def.use.at(n);
+    for (const Var& v : a.live_out[n]) {
+      if (!v.same_as(use_def.def.at(n))) {
+        a.live_in[n].insert(v);
+      }
+    }
+
+    a.live_out[n] = VarSet();
+    for (const CFG::NodePtr& s : n->GetSucc()) {
+      a.live_out[n].insert(a.live_in[s].begin(), a.live_in[s].end());
+    }
+
+    if (SetEqual(old_in_n, a.live_in[n]) && SetEqual(old_out_n, a.live_out[n])) {
+      // No need to update the worklist.
+    } else {
+      // Add predecessor nodes back to worklist (no need to add successors, since each node's
+      // in/out sets are not dependent on its predecessors).
+      for (const CFG::NodePtr& p : n->GetPred()) {
+        worklist.push_back(p);
+      }
+    }
+  };
+
+  while (!worklist.empty()) {
+    const CFG::NodePtr n = worklist.front();
+    worklist.pop_front();
+    visitor(n);
+  }
+
+  return a;
+}
+
+}  // namespace transform
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/backend/liveness_analysis.h b/src/relay/backend/liveness_analysis.h
new file mode 100644
index 0000000000..4e9514056b
--- /dev/null
+++ b/src/relay/backend/liveness_analysis.h
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/liveness_analysis.h
+ * \brief  Analysis that collects the live variables before and after each node.
+ * NOTE: the input IR should be in ANF.
+ */
+
+#ifndef TVM_RELAY_BACKEND_LIVENESS_ANALYSIS_H_
+#define TVM_RELAY_BACKEND_LIVENESS_ANALYSIS_H_
+
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "../../support/arena.h"
+#include "../op/memory/device_copy.h"
+#include "../transforms/device_aware_visitors.h"
+#include "../transforms/let_list.h"
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+using support::Arena;
+using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+
+// TODO(@altanh, @mbs, @mbrookhart): we should do a survey of all "*-flow graphs" in the codebase
+//                                   to see what can be deduplicated.
+
+// TODO(@altanh): support Relay Refs once/if they are supported by the VM.
+
+/*!
+ * \brief A representation of an input expression (typically a Function) as a directed graph of
+ * basic blocks, with edges between basic blocks corresponding to control flow branching.
+ */
+class ControlFlowGraph {
+ public:
+  struct Node;
+  struct BasicBlock;
+
+  using NodePtr = Node*;
+  using BasicBlockPtr = BasicBlock*;
+
+  /*!
+   * \brief A chunk of IR that does not have any control flow branching. At this stage in the IR,
+   * basic blocks correspond to:
+   *   (1) a sequence of nested Let expressions, where each node in the block corresponds to a
+   *       binding and the last node is either the (non-Let) body or a binding that branches
+   *       (e.g. "let %x = if (%c) { true_block } else { false_block }").
+   *   (2) an atomic expression representing the target expression of a control flow branch, e.g.
+   *       %v and %u in "let %x = if (%c) { %v } else { %u }".
+   */
+  struct BasicBlock {
+    // The nodes of the basic block.
+    std::vector<NodePtr> nodes;
+    // The predecessor basic blocks.
+    std::vector<BasicBlockPtr> pred;
+    // The successor basic blocks.
+    std::vector<BasicBlockPtr> succ;
+
+    static BasicBlockPtr Make(support::Arena* arena) { return arena->make<BasicBlock>(); }
+  };
+
+  /*!
+   * \brief Roughly corresponds to a "statement" in the IR, such as an individual binding in a
+   * basic block or the "return value" of a block. Each node maps to a single corresponding expr in
+   * the IR, but the converse is not true (e.g. in the case of variables).
+   */
+  struct Node {
+    /*! \brief The basic block this node belongs to. */
+    BasicBlockPtr parent;
+    /*! \brief The index into the parent basic block where this node is. */
+    size_t index;
+    /*! \brief The expr this node corresponds to. */
+    Expr expr;
+
+    /*! \brief Returns whether or not this node is the first one in the parent basic block. */
+    bool IsFirst() const { return index == 0; }
+
+    /*! \brief Returns whether or not this node is the last one in the parent basic block. */
+    bool IsLast() const { return index == parent->nodes.size() - 1; }
+
+    /*! \brief Returns the predecessor nodes of this node. */
+    std::vector<NodePtr> GetPred() const {
+      std::vector<NodePtr> pred;
+      if (IsFirst()) {
+        for (const BasicBlockPtr& pred_block : parent->pred) {
+          pred.push_back(pred_block->nodes.back());
+        }
+      } else {
+        pred.push_back(parent->nodes[index - 1]);
+      }
+      return pred;
+    }
+
+    /*! \brief Returns the successor nodes of this node. */
+    std::vector<NodePtr> GetSucc() const {
+      std::vector<NodePtr> succ;
+      if (IsLast()) {
+        for (const BasicBlockPtr& succ_block : parent->succ) {
+          succ.push_back(succ_block->nodes.front());
+        }
+      } else {
+        succ.push_back(parent->nodes[index + 1]);
+      }
+      return succ;
+    }
+
+    /*! \brief Creates a node with the given expr and appends it to the parent basic block. */
+    static NodePtr Make(Arena* arena, BasicBlockPtr parent, Expr expr) {
+      NodePtr n = arena->make<Node>();
+      n->parent = parent;
+      n->expr = expr;
+      n->index = parent->nodes.size();
+      parent->nodes.push_back(n);
+      return n;
+    }
+  };
+
+  /*! \brief The basic block where control flow begins. */
+  BasicBlockPtr entry;
+
+  /*!
+   * \brief Mapping from Let expressions to their corresponding nodes. Note that Let expressions
+   * are never shared in ANF (unlike vars), so this is an injection.
+   */
+  std::unordered_map<Expr, NodePtr, ObjectPtrHash, ObjectPtrEqual> let_map;
+
+  /*! \brief The nodes of the CFG in reverse post order. */
+  std::vector<NodePtr> reverse_post_order;
+
+  /*! \brief Creates and returns the CFG of the given expression. */
+  static ControlFlowGraph Create(Arena* arena, const Expr& body);
+
+ private:
+  class Creator;
+};
+
+/*! \brief Helper class for building CFGs. */
+class ControlFlowGraph::Creator : private ExprFunctor<void(const Expr&, BasicBlockPtr)> {
+ public:
+  Creator() {}
+
+  ControlFlowGraph Create(Arena* arena, const Expr& body);
+
+ private:
+  /*! \brief The arena allocator. */
+  Arena* arena_;
+
+  /*! \brief The CFG being built. */
+  ControlFlowGraph cfg_;
+  /*!
+   * \brief Whether or not we are in a function. CFGs do not support nested functions so this is
+   * used to error out in such a case.
+   */
+  bool in_func_ = false;
+
+  /*!
+   * \brief Link \p to as a successor block to \p from.
+   */
+  void Succ(BasicBlockPtr from, BasicBlockPtr to);
+
+#define DEFAULT_CFG(OP)                                       \
+  void VisitExpr_(const OP* op, BasicBlockPtr parent) final { \
+    NodePtr n = Node::Make(arena_, parent, GetRef<Expr>(op)); \
+    cfg_.reverse_post_order.push_back(n);                     \
+  }
+
+  void VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) final;
+  void VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) final;
+  void VisitExpr_(const IfNode* if_node, BasicBlockPtr parent);
+  void VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent);
+
+  DEFAULT_CFG(VarNode);
+  DEFAULT_CFG(GlobalVarNode);
+  DEFAULT_CFG(ConstantNode);
+  DEFAULT_CFG(CallNode);
+  DEFAULT_CFG(OpNode);
+  DEFAULT_CFG(TupleNode);
+  DEFAULT_CFG(TupleGetItemNode);
+};
+
+/*!
+ * \brief Helper class for collecting the variables used/read by an expression. NOTE: for If exprs,
+ * only the condition is included (not the branches). Similarly, for Match exprs only the value
+ * being deconstructed is included.
+ */
+class VarUseCollector : public ExprFunctor<VarSet(const Expr& e)> {
+ public:
+  VarSet VisitExpr_(const VarNode* var_node);
+  VarSet VisitExpr_(const CallNode* call_node);
+  VarSet VisitExpr_(const TupleNode* tuple_node);
+  VarSet VisitExpr_(const TupleGetItemNode* get_node);
+  VarSet VisitExpr_(const IfNode* if_node);
+  VarSet VisitExpr_(const MatchNode* match_node);
+
+  VarSet VisitExpr_(const ConstructorNode* cons_node) { return {}; }
+  VarSet VisitExpr_(const GlobalVarNode* gvar_node) { return {}; }
+  VarSet VisitExpr_(const ConstantNode* const_node) { return {}; }
+  VarSet VisitExpr_(const OpNode* op_node) { return {}; }
+  VarSet VisitExpr_(const FunctionNode* func_node) { return {}; }
+};
+
+/*!
+ * \brief Analysis that collects the variables used and defined at each node.
+ */
+struct UseDefAnalysis {
+  using CFG = ControlFlowGraph;
+
+  /*! \brief Mapping of node -> variables used/read by node. */
+  std::unordered_map<CFG::NodePtr, VarSet> use;
+
+  /*! \brief Mapping of node -> variable defined/written by node. */
+  std::unordered_map<CFG::NodePtr, Var> def;
+
+  VarUseCollector use_collector;
+
+  static UseDefAnalysis Analyze(const CFG& cfg);
+};
+
+/*! \brief Returns whether \p a and \p b are the same set of vars. */
+bool SetEqual(const VarSet& a, const VarSet& b);
+
+/*!
+ * \brief Analysis that collects the live variables before and after each node.
+ */
+struct LivenessAnalysis {
+  using CFG = ControlFlowGraph;
+
+  /*! \brief Mapping of node -> set of variables live before node. */
+  std::unordered_map<CFG::NodePtr, VarSet> live_in;
+
+  /*! \brief Mapping of node -> set of variables live after node. */
+  std::unordered_map<CFG::NodePtr, VarSet> live_out;
+
+  /*!
+   * \brief Analyze the input \p cfg (using info from \p use_def).
+   *
+   * \param cfg The input control flow graph.
+   * \param use_def Use-def analysis of \p cfg.
+   * \return LivenessAnalysis
+   */
+  static LivenessAnalysis Analyze(const ControlFlowGraph& cfg, const UseDefAnalysis& use_def);
+};
+
+}  // namespace transform
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_BACKEND_LIVENESS_ANALYSIS_H_
diff --git a/src/relay/backend/vm/manifest_lifetimes.cc b/src/relay/backend/vm/manifest_lifetimes.cc
index 3ba129702b..486e063203 100644
--- a/src/relay/backend/vm/manifest_lifetimes.cc
+++ b/src/relay/backend/vm/manifest_lifetimes.cc
@@ -29,398 +29,12 @@
 #include "../../op/memory/device_copy.h"
 #include "../../transforms/device_aware_visitors.h"
 #include "../../transforms/let_list.h"
+#include "../liveness_analysis.h"
 
 namespace tvm {
 namespace relay {
 namespace transform {
 
-using support::Arena;
-using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
-
-// TODO(@altanh, @mbs, @mbrookhart): we should do a survey of all "*-flow graphs" in the codebase
-//                                   to see what can be deduplicated.
-
-// TODO(@altanh): support Relay Refs once/if they are supported by the VM.
-
-/*!
- * \brief A representation of an input expression (typically a Function) as a directed graph of
- * basic blocks, with edges between basic blocks corresponding to control flow branching.
- */
-class ControlFlowGraph {
- public:
-  struct Node;
-  struct BasicBlock;
-
-  using NodePtr = Node*;
-  using BasicBlockPtr = BasicBlock*;
-
-  /*!
-   * \brief A chunk of IR that does not have any control flow branching. At this stage in the IR,
-   * basic blocks correspond to:
-   *   (1) a sequence of nested Let expressions, where each node in the block corresponds to a
-   *       binding and the last node is either the (non-Let) body or a binding that branches
-   *       (e.g. "let %x = if (%c) { true_block } else { false_block }").
-   *   (2) an atomic expression representing the target expression of a control flow branch, e.g.
-   *       %v and %u in "let %x = if (%c) { %v } else { %u }".
-   */
-  struct BasicBlock {
-    // The nodes of the basic block.
-    std::vector<NodePtr> nodes;
-    // The predecessor basic blocks.
-    std::vector<BasicBlockPtr> pred;
-    // The successor basic blocks.
-    std::vector<BasicBlockPtr> succ;
-
-    static BasicBlockPtr Make(Arena* arena) { return arena->make<BasicBlock>(); }
-  };
-
-  /*!
-   * \brief Roughly corresponds to a "statement" in the IR, such as an individual binding in a
-   * basic block or the "return value" of a block. Each node maps to a single corresponding expr in
-   * the IR, but the converse is not true (e.g. in the case of variables).
-   */
-  struct Node {
-    /*! \brief The basic block this node belongs to. */
-    BasicBlockPtr parent;
-    /*! \brief The index into the parent basic block where this node is. */
-    size_t index;
-    /*! \brief The expr this node corresponds to. */
-    Expr expr;
-
-    /*! \brief Returns whether or not this node is the first one in the parent basic block. */
-    bool IsFirst() const { return index == 0; }
-
-    /*! \brief Returns whether or not this node is the last one in the parent basic block. */
-    bool IsLast() const { return index == parent->nodes.size() - 1; }
-
-    /*! \brief Returns the predecessor nodes of this node. */
-    std::vector<NodePtr> GetPred() const {
-      std::vector<NodePtr> pred;
-      if (IsFirst()) {
-        for (const BasicBlockPtr& pred_block : parent->pred) {
-          pred.push_back(pred_block->nodes.back());
-        }
-      } else {
-        pred.push_back(parent->nodes[index - 1]);
-      }
-      return pred;
-    }
-
-    /*! \brief Returns the successor nodes of this node. */
-    std::vector<NodePtr> GetSucc() const {
-      std::vector<NodePtr> succ;
-      if (IsLast()) {
-        for (const BasicBlockPtr& succ_block : parent->succ) {
-          succ.push_back(succ_block->nodes.front());
-        }
-      } else {
-        succ.push_back(parent->nodes[index + 1]);
-      }
-      return succ;
-    }
-
-    /*! \brief Creates a node with the given expr and appends it to the parent basic block. */
-    static NodePtr Make(Arena* arena, BasicBlockPtr parent, Expr expr) {
-      NodePtr n = arena->make<Node>();
-      n->parent = parent;
-      n->expr = expr;
-      n->index = parent->nodes.size();
-      parent->nodes.push_back(n);
-      return n;
-    }
-  };
-
-  /*! \brief The basic block where control flow begins. */
-  BasicBlockPtr entry;
-
-  /*!
-   * \brief Mapping from Let expressions to their corresponding nodes. Note that Let expressions
-   * are never shared in ANF (unlike vars), so this is an injection.
-   */
-  std::unordered_map<Expr, NodePtr, ObjectPtrHash, ObjectPtrEqual> let_map;
-
-  /*! \brief The nodes of the CFG in reverse post order. */
-  std::vector<NodePtr> reverse_post_order;
-
-  /*! \brief Creates and returns the CFG of the given expression. */
-  static ControlFlowGraph Create(Arena* arena, const Expr& body);
-
- private:
-  class Creator;
-};
-
-/*! \brief Helper class for building CFGs. */
-class ControlFlowGraph::Creator : private ExprFunctor<void(const Expr&, BasicBlockPtr)> {
- public:
-  Creator() {}
-
-  ControlFlowGraph Create(Arena* arena, const Expr& body) {
-    arena_ = arena;
-    cfg_.entry = BasicBlock::Make(arena);
-    VisitExpr(body, cfg_.entry);
-    return std::move(cfg_);
-  }
-
- private:
-  /*! \brief The arena allocator. */
-  Arena* arena_;
-
-  /*! \brief The CFG being built. */
-  ControlFlowGraph cfg_;
-  /*!
-   * \brief Whether or not we are in a function. CFGs do not support nested functions so this is
-   * used to error out in such a case.
-   */
-  bool in_func_ = false;
-
-  /*!
-   * \brief Link \p to as a successor block to \p from.
-   */
-  void Succ(BasicBlockPtr from, BasicBlockPtr to) {
-    from->succ.push_back(to);
-    to->pred.push_back(from);
-  }
-
-#define DEFAULT_CFG(OP)                                       \
-  void VisitExpr_(const OP* op, BasicBlockPtr parent) final { \
-    NodePtr n = Node::Make(arena_, parent, GetRef<Expr>(op)); \
-    cfg_.reverse_post_order.push_back(n);                     \
-  }
-
-  void VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) final {
-    ICHECK(!in_func_) << "nested functions not supported by CFG analysis";
-    in_func_ = true;
-
-    // Unwrap the nested function and proceed normally.
-    if (f->HasNonzeroAttr(attr::kClosure)) {
-      ICHECK(f->body.as<FunctionNode>());
-      return VisitExpr(Downcast<Function>(f->body)->body, parent);
-    }
-
-    return VisitExpr(f->body, parent);
-  }
-
-  void VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) final {
-    Expr expr = GetRef<Expr>(let_node);
-
-    while (const LetNode* inner_let_node = expr.as<LetNode>()) {
-      NodePtr curr_node = Node::Make(arena_, parent, expr);
-
-      ICHECK(!cfg_.let_map.count(expr));
-      cfg_.let_map[expr] = curr_node;
-      cfg_.reverse_post_order.push_back(curr_node);
-
-      // The basic block ends upon reaching control flow, with successor blocks corresponding to the
-      // control flow branch exprs (true/false in If, and one for each clause in Match).
-      if (const IfNode* ite = AsIgnoringOnDevice<IfNode>(inner_let_node->value)) {
-        // Create the basic blocks for each branch and mark them as successors to the current block.
-        BasicBlockPtr t_block = BasicBlock::Make(arena_);
-        BasicBlockPtr f_block = BasicBlock::Make(arena_);
-        Succ(parent, t_block);
-        Succ(parent, f_block);
-
-        VisitExpr(ite->true_branch, t_block);
-        VisitExpr(ite->false_branch, f_block);
-
-        // All subsequent bindings (and/or the body expr) will be in a new basic block.
-        BasicBlockPtr next = BasicBlock::Make(arena_);
-        Succ(t_block, next);
-        Succ(f_block, next);
-        parent = next;
-      } else if (const MatchNode* match = AsIgnoringOnDevice<MatchNode>(inner_let_node->value)) {
-        // Same as above but one for each pattern.
-        std::vector<BasicBlockPtr> clause_blocks;
-        BasicBlockPtr next = BasicBlock::Make(arena_);
-        for (const Clause& clause : match->clauses) {
-          BasicBlockPtr clause_block = BasicBlock::Make(arena_);
-          Succ(parent, clause_block);
-          Succ(clause_block, next);
-          VisitExpr(clause->rhs, clause_block);
-        }
-        parent = next;
-      }
-
-      expr = inner_let_node->body;
-    }
-
-    VisitExpr(expr, parent);
-  }
-
-  void VisitExpr_(const IfNode* if_node, BasicBlockPtr parent) {
-    // TODO(@altanh): is there a way of making this work?
-    LOG(FATAL) << "If expressions should be bound to variables.";
-  }
-
-  void VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent) {
-    // TODO(@altanh): same as If
-    LOG(FATAL) << "Match expressions should be bound to variables.";
-  }
-
-  DEFAULT_CFG(VarNode);
-  DEFAULT_CFG(GlobalVarNode);
-  DEFAULT_CFG(ConstantNode);
-  DEFAULT_CFG(CallNode);
-  DEFAULT_CFG(OpNode);
-  DEFAULT_CFG(TupleNode);
-  DEFAULT_CFG(TupleGetItemNode);
-};
-
-ControlFlowGraph ControlFlowGraph::Create(Arena* arena, const Expr& body) {
-  return Creator().Create(arena, body);
-}
-
-/*!
- * \brief Helper class for collecting the variables used/read by an expression. NOTE: for If exprs,
- * only the condition is included (not the branches). Similarly, for Match exprs only the value
- * being deconstructed is included.
- */
-class VarUseCollector : public ExprFunctor<VarSet(const Expr& e)> {
- public:
-  VarSet VisitExpr_(const VarNode* var_node) { return {GetRef<Var>(var_node)}; }
-
-  VarSet VisitExpr_(const CallNode* call_node) {
-    VarSet use = VisitExpr(call_node->op);
-    for (const Expr& arg : call_node->args) {
-      VarSet arg_use = VisitExpr(arg);
-      use.insert(arg_use.begin(), arg_use.end());
-    }
-    return use;
-  }
-
-  VarSet VisitExpr_(const TupleNode* tuple_node) {
-    VarSet use;
-    for (const Expr& field : tuple_node->fields) {
-      VarSet field_use = VisitExpr(field);
-      use.insert(field_use.begin(), field_use.end());
-    }
-    return use;
-  }
-
-  VarSet VisitExpr_(const TupleGetItemNode* get_node) { return VisitExpr(get_node->tuple); }
-
-  VarSet VisitExpr_(const IfNode* if_node) { return VisitExpr(if_node->cond); }
-
-  VarSet VisitExpr_(const MatchNode* match_node) { return VisitExpr(match_node->data); }
-
-  VarSet VisitExpr_(const ConstructorNode* cons_node) { return {}; }
-
-  VarSet VisitExpr_(const GlobalVarNode* gvar_node) { return {}; }
-
-  VarSet VisitExpr_(const ConstantNode* const_node) { return {}; }
-
-  VarSet VisitExpr_(const OpNode* op_node) { return {}; }
-};
-
-/*!
- * \brief Analysis that collects the variables used and defined at each node.
- */
-struct UseDefAnalysis {
-  using CFG = ControlFlowGraph;
-
-  /*! \brief Mapping of node -> variables used/read by node. */
-  std::unordered_map<CFG::NodePtr, VarSet> use;
-
-  /*! \brief Mapping of node -> variable defined/written by node. */
-  std::unordered_map<CFG::NodePtr, Var> def;
-
-  VarUseCollector use_collector;
-
-  static UseDefAnalysis Analyze(const CFG& cfg) {
-    UseDefAnalysis a;
-
-    // One pass is sufficient.
-    for (auto it = cfg.reverse_post_order.begin(); it != cfg.reverse_post_order.end(); ++it) {
-      const CFG::NodePtr& node = *it;
-      if (const LetNode* let_node = AsIgnoringOnDevice<LetNode>(node->expr)) {
-        a.use[node] = a.use_collector.VisitExpr(let_node->value);
-        a.def[node] = let_node->var;
-      } else {
-        a.use[node] = a.use_collector.VisitExpr(node->expr);
-        a.def[node] = Var();
-      }
-    }
-
-    return a;
-  }
-};
-
-/*! \brief Returns whether \p a and \p b are the same set of vars. */
-bool SetEqual(const VarSet& a, const VarSet& b) {
-  if (a.size() != b.size()) {
-    return false;
-  }
-  for (auto& xa : a) {
-    if (!b.count(xa)) {
-      return false;
-    }
-  }
-  return true;
-}
-
-/*!
- * \brief Analysis that collects the live variables before and after each node.
- */
-struct LivenessAnalysis {
-  using CFG = ControlFlowGraph;
-
-  /*! \brief Mapping of node -> set of variables live before node. */
-  std::unordered_map<CFG::NodePtr, VarSet> live_in;
-
-  /*! \brief Mapping of node -> set of variables live after node. */
-  std::unordered_map<CFG::NodePtr, VarSet> live_out;
-
-  /*!
-   * \brief Analyze the input \p cfg (using info from \p use_def).
-   *
-   * \param cfg The input control flow graph.
-   * \param use_def Use-def analysis of \p cfg.
-   * \return LivenessAnalysis
-   */
-  static LivenessAnalysis Analyze(const ControlFlowGraph& cfg, const UseDefAnalysis& use_def) {
-    LivenessAnalysis a;
-    std::list<CFG::NodePtr> worklist;
-
-    // Initialize worklist to post-order traversal for quick convergence.
-    worklist.insert(worklist.end(), cfg.reverse_post_order.rbegin(), cfg.reverse_post_order.rend());
-
-    // See https://lambda.uta.edu/cse5317/notes/node40.html for an overview of the algorithm.
-    auto visitor = [&](const CFG::NodePtr n) {
-      VarSet old_in_n = a.live_in[n];
-      VarSet old_out_n = a.live_out[n];
-
-      a.live_in[n] = use_def.use.at(n);
-      for (const Var& v : a.live_out[n]) {
-        if (!v.same_as(use_def.def.at(n))) {
-          a.live_in[n].insert(v);
-        }
-      }
-
-      a.live_out[n] = VarSet();
-      for (const CFG::NodePtr& s : n->GetSucc()) {
-        a.live_out[n].insert(a.live_in[s].begin(), a.live_in[s].end());
-      }
-
-      if (SetEqual(old_in_n, a.live_in[n]) && SetEqual(old_out_n, a.live_out[n])) {
-        // No need to update the worklist.
-      } else {
-        // Add predecessor nodes back to worklist (no need to add successors, since each node's
-        // in/out sets are not dependent on its predecessors).
-        for (const CFG::NodePtr& p : n->GetPred()) {
-          worklist.push_back(p);
-        }
-      }
-    };
-
-    while (!worklist.empty()) {
-      const CFG::NodePtr n = worklist.front();
-      worklist.pop_front();
-      visitor(n);
-    }
-
-    return a;
-  }
-};
-
 /*!
  * \brief Helper class to insert kills using liveness information.
  */
diff --git a/tests/python/relay/test_used_memory_annotator.py b/tests/python/relay/test_used_memory_annotator.py
new file mode 100644
index 0000000000..e339152294
--- /dev/null
+++ b/tests/python/relay/test_used_memory_annotator.py
@@ -0,0 +1,434 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+
+"""
+Testing for the pass that annotates used memory for each primitive
+Relay function.
+"""
+
+import pytest
+
+import tvm
+from tvm import relay
+from tvm.relay.expr_functor import ExprVisitor
+
+
+def AnnotateUsedMemory():
+    return relay.transform._ffi_api.AnnotateUsedMemory()
+
+
+class CheckUsedMemoryAnnotation(ExprVisitor):
+    """
+    Check that the annotations on each function in the graph match
+    what is expected.
+    """
+
+    def __init__(self, expected_annotations, expected_io_annotation):
+        self.expected_annotations = expected_annotations
+        self.expected_io_annotation = expected_io_annotation
+        super().__init__()
+
+    def visit_function(self, fn):
+        if "Primitive" in fn.attrs:
+            assert (
+                "used_memory" in fn.attrs
+            ), "Primitive function does not have used_memory annotation."
+
+            assert len(self.expected_annotations) > 0, "Not all expected annotations were compared"
+
+            expected_mem = self.expected_annotations.pop(0)
+            actual_mem = [int(x) for x in fn.attrs["used_memory"]]
+            assert expected_mem == actual_mem, (
+                f"Expected used memory annotation {expected_mem} "
+                f"did not match actual annotation {actual_mem}"
+            )
+        super().visit_function(fn)
+
+    def __call__(self, fn):
+        assert (
+            fn.attrs["io_used_memory"] == self.expected_io_annotation
+        ), "Expected IO annotation did not match."
+        self.visit(fn.body)
+
+
+def _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation):
+    mod = relay.transform.InferType()(mod)
+    mod = relay.transform.ToANormalForm()(mod)
+    mod = relay.transform.InferType()(mod)
+    mod = AnnotateUsedMemory()(mod)
+
+    CheckUsedMemoryAnnotation(expected_annotations, expected_io_annotation)(mod["main"])
+
+
+def _create_primitive_function(expr):
+    func = relay.Function(relay.analysis.free_vars(expr), expr)
+    func = func.with_attr("Primitive", 1)
+    return func
+
+
+def test_simple():
+    """
+    Test simple graph with one primitive function.
+    """
+
+    def get_inner_func():
+        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
+        x = relay.nn.max_pool2d(x)
+        x = _create_primitive_function(x)
+        return x
+
+    ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8")
+    call = relay.Call(get_inner_func(), [ifm])
+    mod = tvm.IRModule.from_expr(call)
+
+    expected_annotations = [
+        [2 * (1 * 2 * 2 * 4)],
+    ]
+    expected_io_annotation = 2 * (1 * 2 * 2 * 4)
+    _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation)
+
+
+def test_multiple_functions():
+    """
+    Test a graph with multiple primitive functions.
+    """
+
+    def get_inner_func(ifm_shape):
+        x = relay.var("x", shape=ifm_shape, dtype="int8")
+        x = relay.nn.max_pool2d(x, pool_size=(2, 2), layout="NHWC")
+        x = _create_primitive_function(x)
+        return x
+
+    ifm = relay.var("input", shape=(1, 8, 8, 2), dtype="int8")
+    x = get_inner_func((1, 8, 8, 2))
+    x = relay.Call(x, [ifm])
+    y = get_inner_func((1, 7, 7, 2))
+    y = relay.Call(y, [x])
+    z = get_inner_func((1, 6, 6, 2))
+    z = relay.Call(z, [y])
+    mod = tvm.IRModule.from_expr(z)
+
+    expected_annotations = [
+        [(1 * 8 * 8 * 2) + (1 * 7 * 7 * 2)],
+        [(1 * 7 * 7 * 2) + (1 * 6 * 6 * 2)],
+        [(1 * 6 * 6 * 2) + (1 * 5 * 5 * 2)],
+    ]
+    expected_io_annotation = (1 * 8 * 8 * 2) + (1 * 5 * 5 * 2)
+    _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation)
+
+
+def test_mixed_data_types():
+    """
+    Test a graph with a primitive function that has mixed datatypes.
+    """
+
+    def get_inner_func():
+        x = relay.var("x", shape=(1, 2, 2, 2), dtype="int16")
+        x = relay.cast(x, dtype="uint32")
+        x = _create_primitive_function(x)
+        return x
+
+    ifm = relay.var("input", shape=(1, 2, 2, 2), dtype="int16")
+    x = get_inner_func()
+    x = relay.Call(x, [ifm])
+    mod = tvm.IRModule.from_expr(x)
+
+    expected_annotations = [
+        [(1 * 2 * 2 * 2) * 2 + (1 * 2 * 2 * 2) * 4],
+    ]
+    expected_io_annotation = (1 * 2 * 2 * 2) * 2 + (1 * 2 * 2 * 2) * 4
+    _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation)
+
+
+def test_parallel_function_call():
+    """
+    Test a graph when the results of two functions are concatenated
+    into a single result. The second function will also have the result
+    of the first function alive so will be annotated with a larger
+    "used memory" value.
+    """
+
+    def get_inner_func():
+        x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8")
+        x = relay.reshape(x, newshape=(1, 4, 30))
+        x = _create_primitive_function(x)
+        return x
+
+    ifm = relay.var("input", shape=(1, 4, 5, 6), dtype="int8")
+    x = relay.Call(get_inner_func(), [ifm])
+    y = relay.Call(get_inner_func(), [ifm])
+    z = relay.concatenate([x, y], axis=0)
+    mod = tvm.IRModule.from_expr(z)
+
+    expected_annotations = [
+        [(1 * 4 * 5 * 6) + (1 * 4 * 30)],
+        # the output tensor from the previous function is also alive
+        [(1 * 4 * 5 * 6) + (1 * 4 * 30) + (1 * 4 * 30)],
+    ]
+    expected_io_annotation = (1 * 4 * 5 * 6) + (1 * 4 * 60)
+    _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation)
+
+
+def test_many_different_parallel_calls():
+    """
+    Test a graph that calls many different functions in parallel.
+
+                    input
+            /         |         \
+    prim_func_1  prim_func_2  prim_func_3
+           \         |         /
+                 prim_func_4
+    """
+
+    def get_inner_func_1():
+        x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8")
+        x = relay.tanh(x)
+        x = _create_primitive_function(x)
+        return x
+
+    def get_inner_func_2():
+        x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8")
+        x = relay.nn.max_pool2d(x, pool_size=(1, 1), layout="NHWC")
+        x = _create_primitive_function(x)
+        return x
+
+    def get_inner_func_3():
+        x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8")
+        x = relay.abs(x)
+        x = relay.nn.relu(x)
+        x = relay.exp(x)
+        x = _create_primitive_function(x)
+        return x
+
+    def get_inner_func_4():
+        x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8")
+        y = relay.var("y", shape=(1, 4, 5, 6), dtype="int8")
+        z = relay.var("z", shape=(1, 4, 5, 6), dtype="int8")
+        out = relay.concatenate([x, y, z], axis=3)
+        out = _create_primitive_function(out)
+        return out
+
+    ifm = relay.var("input", shape=(1, 4, 5, 6), dtype="int8")
+    x = relay.Call(get_inner_func_1(), [ifm])
+    y = relay.Call(get_inner_func_2(), [ifm])
+    z = relay.Call(get_inner_func_3(), [ifm])
+    a = relay.Call(get_inner_func_4(), [x, y, z])
+    mod = tvm.IRModule.from_expr(a)
+
+    expected_annotations = [
+        [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6)],
+        # output from prim_func_1 is also still alive
+        [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6)],
+        # outputs from prim_func_1 and prim_func_2 are also still alive
+        [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6)],
+        [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 18)],
+    ]
+    expected_io_annotation = (1 * 4 * 5 * 6) + (1 * 4 * 5 * 18)
+    _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation)
+
+
+def test_nested_branches():
+    """
+    Tests a graph with branches that also branch.
+
+             input
+            /     \
+          /        \
+    prim_func_1  prim_func_2
+                   /     \
+                  /       \
+            prim_func_3   prim_func_4
+    """
+
+    def get_generic_inner_func():
+        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
+        x = relay.nn.relu(x)
+        return _create_primitive_function(x)
+
+    ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8")
+    a = relay.Call(get_generic_inner_func(), [ifm])
+    b = relay.Call(get_generic_inner_func(), [ifm])
+    c = relay.Call(get_generic_inner_func(), [b])
+    d = relay.Call(get_generic_inner_func(), [b])
+    out = relay.concatenate([a, c, d], axis=3)
+    mod = tvm.IRModule.from_expr(out)
+
+    expected_annotations = [
+        [(1 * 2 * 2 * 4) + (1 * 2 * 2 * 4)],
+        # output from prim_func_1 is also still alive
+        [(1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4)],
+        # output from prim_func_1 is also still alive
+        [(1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4)],
+        # outputs from prim_func_1 and prim_func_3 are also still alive
+        [(1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4)],
+    ]
+    expected_io_annotation = (1 * 2 * 2 * 4) + (1 * 2 * 2 * 12)
+    _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation)
+
+
+def test_composite_inner_function():
+    """
+    Tests the typical BYOC use case where a primitive function
+    contains a composite function.
+    """
+
+    def get_inner_func():
+        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
+        x = relay.nn.max_pool2d(x, pool_size=(2, 2), layout="NHWC")
+        x = relay.Function(relay.analysis.free_vars(x), x)
+        x = x.with_attr("Composite", "my_composite_func")
+
+        y = relay.var("y", shape=(1, 2, 2, 4), dtype="int8")
+        z = relay.Call(x, [y])
+        return _create_primitive_function(z)
+
+    ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8")
+    x = relay.Call(get_inner_func(), [ifm])
+    mod = tvm.IRModule.from_expr(x)
+
+    expected_annotations = [
+        [(1 * 2 * 2 * 4) + (1 * 1 * 1 * 4)],
+    ]
+    expected_io_annotation = (1 * 2 * 2 * 4) + (1 * 1 * 1 * 4)
+    _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation)
+
+
+def test_multiple_calls_to_same_function():
+    """
+    Tests the case when there are multiple calls to the same function.
+    """
+
+    def get_inner_func():
+        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
+        x = relay.nn.max_pool2d(x)
+        x = _create_primitive_function(x)
+        return x
+
+    inner_func = get_inner_func()
+    ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8")
+    call1 = relay.Call(inner_func, [ifm])
+    call2 = relay.Call(inner_func, [call1])
+    mod = tvm.IRModule.from_expr(call2)
+
+    expected_annotations = [[2 * (1 * 2 * 2 * 4), 2 * (1 * 2 * 2 * 4)]]
+    expected_io_annotation = 2 * (1 * 2 * 2 * 4)
+    _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation)
+
+
+def test_parallel_calls_to_same_function():
+    """
+    Test parallel calls to the same function.
+    """
+
+    def get_inner_func():
+        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
+        x = relay.nn.max_pool2d(x)
+        x = _create_primitive_function(x)
+        return x
+
+    inner_func = get_inner_func()
+    ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8")
+    call1 = relay.Call(inner_func, [ifm])
+    call2 = relay.Call(inner_func, [ifm])
+    concat = relay.concatenate([call1, call2], axis=0)
+    mod = tvm.IRModule.from_expr(concat)
+
+    expected_annotations = [[2 * (1 * 2 * 2 * 4), 3 * (1 * 2 * 2 * 4)]]
+    expected_io_annotation = 3 * (1 * 2 * 2 * 4)
+    _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation)
+
+
+def test_parallel_calls_with_non_ifm_input():
+    """
+    Test a graph that calls many different functions in parallel where
+    the input is not the input to the function.
+
+                    y = f(x)
+            /         |         \
+       z0 = g0(y)    ...      zi = gi(y)
+           \         |         /
+                  concat
+    """
+
+    def get_inner_func_1():
+        x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8")
+        x = relay.tanh(x)
+        x = _create_primitive_function(x)
+        return x
+
+    def get_inner_func_2():
+        x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8")
+        x = relay.nn.max_pool2d(x, pool_size=(2, 2))
+        x = _create_primitive_function(x)
+        return x
+
+    ifm = relay.var("input", shape=(1, 4, 5, 6), dtype="int8")
+    y = relay.Call(get_inner_func_1(), [ifm])
+    g = get_inner_func_2()
+
+    no_calls = 20
+    z = [relay.Call(g, [y]) for _ in range(0, no_calls)]
+    out = relay.concatenate(z, axis=3)
+    mod = tvm.IRModule.from_expr(out)
+
+    expected_annotations = [
+        [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6)],
+        [(1 * 4 * 5 * 6) + (1 * 4 * 4 * 5) * i for i in range(1, no_calls + 1)],
+    ]
+    expected_io_annotation = (1 * 4 * 5 * 6) + (1 * 4 * 4 * (5 * no_calls))
+    _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation)
+
+
+def test_dynamic_io_tensor_not_supported():
+    """
+    Test to check dynamic IO tensor error.
+    """
+
+    def get_inner_func():
+        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
+        x = relay.nn.max_pool2d(x)
+        x = _create_primitive_function(x)
+        return x
+
+    ifm = relay.var("input", shape=(1, 2, 2, relay.Any()), dtype="int8")
+    call = relay.Call(get_inner_func(), [ifm])
+    mod = tvm.IRModule.from_expr(call)
+
+    err_rgx = r"AnnotateUsedMemory does not support dynamic shapes"
+    with pytest.raises(tvm.TVMError, match=err_rgx):
+        _check_used_memory_annotations(mod, [], [])
+
+
+def test_dynamic_callsite_tensor_not_supported():
+    """
+    Test to check dynamic callsite tensor error.
+    """
+
+    def get_inner_func():
+        x = relay.var("x", shape=(relay.Any(), 2, 2, 4), dtype="int8")
+        x = relay.nn.max_pool2d(x)
+        x = _create_primitive_function(x)
+        return x
+
+    ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8")
+    call = relay.Call(get_inner_func(), [ifm])
+    mod = tvm.IRModule.from_expr(call)
+
+    err_rgx = r"AnnotateUsedMemory does not support dynamic shapes"
+    with pytest.raises(tvm.TVMError, match=err_rgx):
+        _check_used_memory_annotations(mod, [], [])