Posted to commits@tvm.apache.org by jr...@apache.org on 2021/11/01 21:39:45 UTC

[tvm] branch main updated: BUG: Make sure FoldConstant can inline constants underneath on_device annotations (#9367)

This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 56c4d07  BUG: Make sure FoldConstant can inline constants underneath on_device annotations (#9367)
56c4d07 is described below

commit 56c4d07ac54d6b96d94f18031ac8ab2aa263d132
Author: Mark Shields <87...@users.noreply.github.com>
AuthorDate: Mon Nov 1 14:39:12 2021 -0700

    BUG: Make sure FoldConstant can inline constants underneath on_device annotations (#9367)
    
    After device planning and conversion to ANF we can end up with:
      let %x = on_device(constant, device_type=D)
      ...
      @f(..., %x, ...)
    where the device D is not the same as the device for the let-expression
    itself. (e.g., D may be the CPU, %x a shape, and @f an allocation primitive
    that requires shapes to reside on the CPU). That's all consistent with the convention
    the DeviceAware* visitors expect for recovering device information.
    However, it means folding the constant into @f's call site must both 'see' the
    constant underneath the on_device annotation and bring the on_device annotation
    along for the ride:
      @f(..., on_device(constant, device_type=D), ...)
    
    - Make FoldConstant be on_device aware
    - Clean things up a bit while I'm there.
    - Set up unit tests specifically for const folding with on_device annotations.
    - Replace if __name__ == "__main__" drivers for unit tests with the official
      incantation as I encounter them.
    - Don't create on_device(on_device(...))
    - Logging changes so A/B diffs can focus on just the pass of interest.
    - Revert Index->DLDeviceType changes in the vm in case they are the cause
      of downstream problems.
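
    For illustration, a minimal sketch of the intended behaviour, mirroring
    the test_fold_let_with_on_device unit test added in this commit (the
    helper names and the choice of a single CPU device are illustrative; it
    assumes the Relay Python API at this revision, i.e.
    relay.op.annotation.on_device / function_on_device and
    transform.FoldConstant):

      import numpy as np
      import tvm
      from tvm import relay
      from tvm.relay import transform

      c_data = np.array(1).astype("float32")
      t = relay.TensorType([1], "float32")

      # Let-bind a constant that is pinned to the CPU via an on_device annotation.
      sb = relay.ScopeBuilder()
      x = relay.var("x", t)
      c = sb.let("c", relay.op.annotation.on_device(relay.const(c_data), tvm.cpu(), is_fixed=True))
      sb.ret(relay.add(c, x))
      f = relay.Function([x], sb.get())
      f = relay.op.annotation.function_on_device(f, [tvm.cpu()], tvm.cpu())

      # Fold: the let-binding is dropped and the inlined constant keeps its
      # on_device annotation, so downstream device-aware passes can still
      # recover the device it must reside on.
      mod = transform.FoldConstant()(tvm.IRModule.from_expr(f))
      print(mod["main"])
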
---
 include/tvm/runtime/vm/vm.h                        |   4 +-
 src/ir/module.cc                                   |   9 +-
 src/relay/backend/te_compiler.cc                   |   4 +-
 src/relay/backend/vm/compiler.cc                   |  25 +-
 src/relay/backend/vm/inline_primitives.cc          |   2 +-
 src/relay/op/annotation/annotation.cc              |  10 +
 src/relay/transforms/device_aware_visitors.cc      |   2 +-
 src/relay/transforms/device_planner.cc             |   2 +-
 src/relay/transforms/fold_constant.cc              | 398 ++++++++++++---------
 src/relay/transforms/memory_alloc.cc               |   1 -
 src/runtime/cuda/cuda_device_api.cc                |  10 +-
 src/runtime/vm/executable.cc                       |   2 +-
 src/runtime/vm/memory_manager.cc                   |   8 +-
 src/runtime/vm/pooled_allocator.h                  |   6 +-
 src/runtime/vm/serialize_utils.h                   |   4 +-
 src/runtime/vm/vm.cc                               |  36 +-
 tests/python/contrib/test_tensorrt.py              |   5 +-
 tests/python/driver/tvmc/test_compiler.py          |   6 +
 tests/python/relay/test_pass_fold_constant.py      | 100 +++++-
 tests/python/relay/test_prng.py                    |   7 +-
 .../python/unittest/test_target_codegen_vulkan.py  |   5 +-
 21 files changed, 405 insertions(+), 241 deletions(-)

diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h
index 039b189..ece73fc 100644
--- a/include/tvm/runtime/vm/vm.h
+++ b/include/tvm/runtime/vm/vm.h
@@ -84,11 +84,11 @@ struct VMFunction {
   /*! \brief The size of the frame for this function */
   Index register_file_size;
   /*! \brief The device type of each parameter for this function. */
-  std::vector<DLDeviceType> params_device_type;
+  std::vector<Index> params_device_type;
 
   VMFunction(const std::string& name, std::vector<std::string> params,
              const std::vector<Instruction>& instructions, Index register_file_size,
-             const std::vector<DLDeviceType> params_device_type = {})
+             const std::vector<Index> params_device_type = {})
       : name(name),
         params(params),
         instructions(instructions),
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 3deb70d..8ea83cf 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -187,9 +187,12 @@ void WarnIfMalformed(const IRModule& mod, relay::Function func) {
   auto fv = relay::FreeVars(func);
   auto ftv = relay::FreeTypeVars(func, mod);
   // TODO(@jroesch): refactor to use diagnostic context
-  ICHECK_EQ(fv.size(), 0) << "There are free variables: " << fv << std::endl;
-  ICHECK_EQ(ftv.size(), 0) << "There are free type variables: " << fv
-                           << " in function: " << AsText(func, false);
+  ICHECK_EQ(fv.size(), 0) << "Function:" << std::endl
+                          << PrettyPrint(func) << std::endl
+                          << "contains free variables: " << fv;
+  ICHECK_EQ(ftv.size(), 0) << "Function:" << std::endl
+                           << PrettyPrint(func) << std::endl
+                           << "contains free type variables: " << ftv;
 }
 
 void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) {
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index a8c27a1..ed774ec 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -562,7 +562,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
     BaseFunc prim_func = ResolveToPrimitive(new_value);
 
     if (prim_func.defined() && !prim_func->IsInstance<tir::PrimFuncNode>()) {
-      // Remember let var is bound to (possibly indirectly) to a non-tir primitive.
+      // Remember let var is bound to (possibly indirectly) a non-tir primitive.
       Function func = Downcast<Function>(prim_func);
       primitive_functions_.emplace(var, func);
     }
@@ -896,8 +896,6 @@ void UpdateFunctionMetadata(Function relay_func,
 
 IRModule LowerTE(const IRModule& module, TargetMap targets, const String& module_name,
                  std::function<void(Function)> process_fn) {
-  DLOG(INFO) << "lowering module:\n" << PrettyPrint(module);
-
   TECompiler compiler;
 
   auto updated_module = LowerTensorExpr(targets, module_name, compiler, process_fn)(module);
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index b3c1cd8..6a085ad 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -304,7 +304,13 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
       }
       VisitExpr(func);
     }
-    return VMFunction(var->name_hint, params_, instructions_, registers_num_, params_device_type);
+    std::vector<Index> params_device_type_index;
+    params_device_type_index.reserve(params_device_type.size());
+    for (auto device_type : params_device_type) {
+      params_device_type_index.push_back(static_cast<Index>(device_type));
+    }
+    return VMFunction(var->name_hint, params_, instructions_, registers_num_,
+                      params_device_type_index);
   }
 
   /*! \brief Attrs objects for each op. */
@@ -317,7 +323,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
   size_t NewRegister() { return registers_num_++; }
 
   inline void Emit(const Instruction& instr) {
-    VLOG(1) << "VMCompiler::Emit: instr=" << instr;
+    VLOG(2) << "VMCompiler::Emit: instr=" << instr;
     ICHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op;
     switch (instr.op) {
       case Opcode::AllocADT:
@@ -703,7 +709,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
       auto global = GetRef<GlobalVar>(global_node);
       auto it = context_->global_map.find(global);
       ICHECK(it != context_->global_map.end());
-      VLOG(1) << "VisitExpr_: generating invoke for " << global->name_hint
+      VLOG(2) << "VisitExpr_: generating invoke for " << global->name_hint
               << " with func_index=" << it->second;
 
       // TODO(tvm-team):
@@ -941,12 +947,6 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe
     }
   }
 
-#if USE_RELAY_DEBUG
-  for (auto vm_func : exec_->functions) {
-    VLOG(1) << vm_func << "-------------";
-  }
-#endif  // USE_RELAY_DEBUG
-
   // populate constants
   for (auto data : context_.constants) {
     exec_->constants.push_back(data);
@@ -967,6 +967,12 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe
     exec_->primitive_map.insert({cfunc->prim_fn_var->name_hint, primitive_index++});
   }
 
+#if USE_RELAY_DEBUG
+  for (const auto& vm_func : exec_->functions) {
+    VLOG(1) << vm_func << "-------------";
+  }
+#endif  // USE_RELAY_DEBUG
+
   backend::UpdateAutoSchedulerOpWeights(context_.compiler);
 }
 
@@ -1018,6 +1024,7 @@ transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) {
 
 IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg,
                                     const Target& target_host_arg) {
+  VLOG_CONTEXT << "VMCompiler::OptimizeModule";
   TargetsMap targets = targets_arg;
   Target target_host = target_host_arg;
   CheckAndUpdateHostConsistency(&targets, &target_host);
diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc
index 6924f25..6744248 100644
--- a/src/relay/backend/vm/inline_primitives.cc
+++ b/src/relay/backend/vm/inline_primitives.cc
@@ -87,7 +87,7 @@ struct PrimitiveInliner : ExprMutator {
     // in w(...)
     while ((var_node = op.as<VarNode>())) {
       auto var = GetRef<Var>(var_node);
-      DLOG(INFO) << "Var: " << var << std::endl;
+      VLOG(1) << "Var: " << var << std::endl;
       auto it = var_map.find(GetRef<Var>(var_node));
       if (it != var_map.end()) {
         op = it->second;
diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc
index 8b00839..27b6133 100644
--- a/src/relay/op/annotation/annotation.cc
+++ b/src/relay/op/annotation/annotation.cc
@@ -76,6 +76,16 @@ Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) {
     // by the function's attributes.
     return expr;
   }
+  OnDeviceProps props = GetOnDeviceProps(expr);
+  if (props.body.defined()) {
+    // Don't nest on_devices.
+    // If the inner and outer device types differ then we need to be careful:
+    //  - If the inner on_device is_fixed then it disagrees with the outer.
+    //  - If the outer on_device is_fixed then it implies a hidden device_copy
+    // Otherwise just use the inner device type and ignore the outer.
+    ICHECK(props.device_type == device_type || (!is_fixed && !props.is_fixed));
+    return OnDevice(props.body, device_type, is_fixed || props.is_fixed);
+  }
   return OnDevice(expr, device_type, is_fixed);
 }
 
diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc
index 28aeab6..38c3305 100644
--- a/src/relay/transforms/device_aware_visitors.cc
+++ b/src/relay/transforms/device_aware_visitors.cc
@@ -262,7 +262,7 @@ Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) {
     Expr expr = VisitExpr(props.body);
     // Leaving lexical scope of "on_device" call.
     PopDeviceType();
-    return OnDevice(expr, props.device_type, props.is_fixed);
+    return MaybeOnDevice(expr, props.device_type, props.is_fixed);
   } else {
     return DeviceAwareVisitExpr_(call_node);
   }
diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc
index dc61e79..83429a9 100644
--- a/src/relay/transforms/device_planner.cc
+++ b/src/relay/transforms/device_planner.cc
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file src/relay/analysis/device_planner.cc
+ * \file src/relay/transforms/device_planner.cc
  * \brief Determines a unique device to hold the result of every Relay sub-expression.
  *
  * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index d545518..c48a9b3 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -21,6 +21,7 @@
  * \file constant_folding.cc
  */
 #include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/interpreter.h>
@@ -30,68 +31,80 @@
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/object.h>
 
-#include "pattern_utils.h"
+#include "../op/annotation/annotation.h"
+#include "./device_aware_visitors.h"
+#include "./pattern_utils.h"
 
 namespace tvm {
 namespace relay {
+namespace transform {
 
-using FInterpreter = runtime::TypedPackedFunc<ObjectRef(Expr)>;
-
-class ConstantChecker : private ExprVisitor {
- public:
-  // Check whether an expression is constant. The results are memoized.
-  bool Check(const Expr& expr) {
-    // The `ConstantNode` case is common enough that we check directly for the
-    // case here, to avoid the time overhead of dispatching through the vtable
-    // and the space overhead of memoizing always-true results.
-    if (expr.as<ConstantNode>()) {
-      return true;
-    }
-    const auto it = memo_.find(expr);
-    if (it != memo_.end()) return it->second;
-    VisitExpr(expr);
-    return memo_[expr];  // return memoized result or the default value false
-  }
+namespace {
+/*!
+ * \brief Returns whether \p expr is a literal \p Constant, optionally wrapped by an "on_device"
+ * annotation CallNode (which serves only to associate a device to the constant and has no
+ * operational effect).
+ */
+bool IsSimpleConstant(const Expr& expr) {
+  return AsIgnoringOnDevice<ConstantNode>(expr) != nullptr;
+}
 
- private:
-  std::unordered_map<Expr, bool, ObjectPtrHash, ObjectPtrEqual> memo_;
-
-  void VisitExpr_(const TupleNode* n) final {
-    bool result = true;
-    for (const auto& field : n->fields) {
-      if (!Check(field)) {
-        result = false;
-        break;
-      }
-    }
-    memo_[GetRef<Tuple>(n)] = result;
+/*!
+ * \brief Returns whether \p expr \p IsSimpleConstant directly or is a tuple of
+ * \p IsComplexConstant expressions.
+ */
+bool IsComplexConstant(const Expr& expr) {
+  if (IsSimpleConstant(expr)) {
+    return true;
+  } else if (const auto* tuple_node = AsIgnoringOnDevice<TupleNode>(expr)) {
+    return std::all_of(tuple_node->fields.begin(), tuple_node->fields.end(), IsComplexConstant);
+  } else {
+    return false;
   }
-};
-
-bool ConstantCheck(const Expr& e) { return ConstantChecker().Check(e); }
-
-TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantCheck);
+}
 
 // TODO(tvm-team) consider combine dead-code with constant folder.
 // or make a more powerful partial evaluator.
 class ConstantFolder : public MixedModeMutator {
  public:
   explicit ConstantFolder(IRModule module)
-      : module_(module),
+      : module_(std::move(module)),
         device_copy_op_(Op::Get("device_copy")),
         shape_of_op_(Op::Get("shape_of")),
         vm_shape_of_op_(Op::Get("vm.shape_of")),
         cast_op_(Op::Get("cast")),
         ndarray_size_op_(Op::Get("ndarray_size")) {}
 
-  using MixedModeMutator::VisitExpr_;
+ private:
+  using ExprMutator::VisitExpr_;
 
-  Expr VisitExpr_(const LetNode* op) final {
+  Expr VisitExpr_(const LetNode* let_node) final {
     auto pre_visit = [this](const LetNode* op) {
       // Rely on the Memoizer to cache pre-visit values
-      Expr value = this->Mutate(op->value);
-      if (value.as<ConstantNode>()) {
-        this->memo_[op->var] = value;
+      Expr new_value = Mutate(op->value);
+      if (IsSimpleConstant(new_value)) {
+        // Inline new value (along with any on_device annotation wrapping it) at all occurrences of
+        // the variable.
+        //
+        // We need to retain any "on_device" annotation so that downstream 'device aware'
+        // passes can still retrieve the device for the constant in its new position(s). Eg:
+        //   def @f(..., result_device_type=D) {
+        //     let %x = on_device(... something we eval to a constant..., device_type=E)
+        //     @f(..., %x, ...)
+        //   }
+        // Here the default device is D, whereas the argument %x to @f is on E (and @f expects
+        // that). No on_device annotation is required in the call according to the convention used
+        // by the device-aware visitors.
+        //
+        // However once we've inlined the constant we need to insert an on_device, again to
+        // respect the convention used by the device-aware visitors.
+        //   def @f(..., result_device_type=D) {
+        //     @f(..., on_device(...the constant..., device_type=E), ...)
+        //   }
+        VLOG(1) << "Replacing let-binding for " << op->var->name_hint()
+                << " with constant:" << std::endl
+                << PrettyPrint(new_value);
+        memo_[op->var] = new_value;
       } else {
         this->Mutate(op->var);
       }
@@ -99,116 +112,117 @@ class ConstantFolder : public MixedModeMutator {
     auto post_visit = [this](const LetNode* op) {
       Expr expr = GetRef<Expr>(op);
       // Rely on the Memoizer to cache pre-visit values
-      Expr value = this->Mutate(op->value);
-      if (value.as<ConstantNode>()) {
-        this->memo_[expr] = this->Mutate(op->body);
+      Expr new_value = this->Mutate(op->value);
+      if (IsSimpleConstant(new_value)) {
+        // The let-bound value has been inlined, drop the let-binding itself.
+        this->memo_[expr] = Mutate(op->body);
       } else {
-        Var var = Downcast<Var>(this->Mutate(op->var));
-        Expr body = this->Mutate(op->body);
-        if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
+        Var new_var = Downcast<Var>(this->Mutate(op->var));
+        Expr new_body = this->Mutate(op->body);
+        if (new_var.same_as(op->var) && new_value.same_as(op->value) &&
+            new_body.same_as(op->body)) {
           this->memo_[expr] = expr;
         } else {
-          this->memo_[expr] = Let(var, value, body);
+          this->memo_[expr] = Let(new_var, new_value, new_body, op->span);
         }
       }
     };
-    ExpandANormalForm(op, pre_visit, post_visit);
-    return memo_[GetRef<Expr>(op)];
+    ExpandANormalForm(let_node, pre_visit, post_visit);
+    return memo_[GetRef<Expr>(let_node)];
   }
 
-  bool inside_primitive = false;
-  Expr VisitExpr_(const FunctionNode* op) final {
-    if (op->HasNonzeroAttr(attr::kPrimitive)) {
-      ICHECK_EQ(inside_primitive, false);
-      inside_primitive = true;
-      auto ret = ExprMutator::VisitExpr_(op);
-      inside_primitive = false;
+  Expr VisitExpr_(const FunctionNode* function_node) final {
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      ICHECK_EQ(inside_primitive_, false);
+      inside_primitive_ = true;
+      auto ret = ExprMutator::VisitExpr_(function_node);
+      inside_primitive_ = false;
       return ret;
     } else {
-      return ExprMutator::VisitExpr_(op);
+      return ExprMutator::VisitExpr_(function_node);
     }
   }
 
-  Expr VisitExpr_(const IfNode* op) final {
-    auto new_cond = ExprMutator::VisitExpr(op->cond);
-    if (auto const_cond = new_cond.as<ConstantNode>()) {
-      if (reinterpret_cast<uint8_t*>(const_cond->data->data)[0]) {
-        return ExprMutator::VisitExpr(op->true_branch);
-      } else {
-        return ExprMutator::VisitExpr(op->false_branch);
-      }
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    Call pre_call = GetRef<Call>(pre_call_node);
+    if (inside_primitive_) {
+      return pre_call;
     }
-    return ExprMutator::VisitExpr_(op);
-  }
 
-  Expr Rewrite_(const CallNode* call, const Expr& post) final {
-    if (inside_primitive) {
-      return GetRef<Expr>(call);
-    }
-    static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
-
-    auto origin_args = call->args;
-    call = post.as<CallNode>();
-    // We don't constant fold function with zero arguments.
-    // This is a heuristic that is useful.
-    // For example it is harmful to fold ones(shape=(4, 5)).
-    if (call->args.size() == 0) return post;
-    const OpNode* op = call->op.as<OpNode>();
-    if (op == nullptr) return post;
-    // skip stateful ops.
-    if (op_stateful.get(GetRef<Op>(op), false)) return post;
-    // Try to evaluate shape_of op
-    if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) {
-      return EvaluateShapeOf(post, origin_args, call->attrs);
-    }
+    Call post_call = Downcast<Call>(post);
 
-    if (call->op == ndarray_size_op_) {
-      return EvaluateNdarraySize(post, origin_args, call->attrs);
+    if (post_call->args.empty()) {
+      // We don't constant fold function with zero arguments.
+      // This is a heuristic that is useful.
+      // For example it is harmful to fold ones(shape=(4, 5)).
+      return std::move(pre_call);
     }
 
-    // We should think about potentially constant evaluation over these ops too.
     static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
-    if (const auto* call_node = call->op.as<OpNode>()) {
-      Op op = GetRef<Op>(call_node);
-      if ((fnoncomputational.count(op) && fnoncomputational[op]) || (call->op == device_copy_op_)) {
-        return GetRef<Call>(call);
-      }
-    }
 
-    bool all_const_args = true;
-    for (Expr arg : call->args) {
-      if (!checker_.Check(arg)) {
-        all_const_args = false;
-      }
+    const auto* op_node = post_call->op.as<OpNode>();
+    if (op_node == nullptr) {
+      // Only evaluate primitives.
+      return std::move(post_call);
     }
-    if (all_const_args) {
-      return ConstEvaluate(post);
-    } else {
-      return post;
+    Op op = GetRef<Op>(op_node);
+    static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
+    if (op_stateful.get(op, false)) {
+      // skip stateful ops.
+      return std::move(post_call);
     }
+    // Try to evaluate shape_of and ndarray_size ops
+    // Use the original call rather than new_call here since it still has valid checked_type
+    // fields. These operators don't care about the value of their argument anyway.
+    if (Optional<Expr> opt_result = EvaluateShapeOf(pre_call)) {
+      return opt_result.value();
+    }
+    // Use the original call rather than new_call here since it still has valid checked_type
+    // fields. This operator doesn't care about the value of its argument anyway.
+    if (Optional<Expr> opt_result = EvaluateNdarraySize(pre_call)) {
+      return opt_result.value();
+    }
+    if ((fnoncomputational.count(op) && fnoncomputational[op]) || op == device_copy_op_ ||
+        op == shape_of_op_ || op == vm_shape_of_op_ || op == ndarray_size_op_) {
+      // We should think about potentially constant evaluation over these ops too.
+      return std::move(post_call);
+    }
+    if (!std::all_of(post_call->args.begin(), post_call->args.end(), IsComplexConstant)) {
+      // At least one non-constant argument.
+      return std::move(post_call);
+    }
+    // During evaluation we have obviously lost all on_device annotations. However any
+    // on_device wrapping this call will be left in place.
+    return ConstEvaluate(post_call);
   }
 
-  Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
-    op = post.as<TupleGetItemNode>();
-    if (const auto* tuple = op->tuple.as<TupleNode>()) {
-      return tuple->fields[op->index];
-    } else {
-      return post;
+  Expr VisitExpr_(const IfNode* if_node) final {
+    If new_if = Downcast<If>(ExprMutator::VisitExpr_(if_node));
+    if (const auto* const_node = AsIgnoringOnDevice<ConstantNode>(new_if->cond)) {
+      if (reinterpret_cast<uint8_t*>(const_node->data->data)[0]) {
+        return new_if->true_branch;
+      } else {
+        return new_if->false_branch;
+      }
     }
+    return std::move(new_if);
   }
 
- private:
-  // Internal constant checker
-  ConstantChecker checker_;
-  // Module
-  IRModule module_;
-
-  // Cache the following ops for equivalence checking in this pass.
-  const Op& device_copy_op_;
-  const Op& shape_of_op_;
-  const Op& vm_shape_of_op_;
-  const Op& cast_op_;
-  const Op& ndarray_size_op_;
+  Expr Rewrite_(const TupleGetItemNode* tuple_get_item_node,
+                const Expr& post_tuple_get_item) final {
+    const auto* post_tuple_get_item_node = post_tuple_get_item.as<TupleGetItemNode>();
+    if (const auto* tuple_node = AsIgnoringOnDevice<TupleNode>(post_tuple_get_item_node->tuple)) {
+      Expr result = tuple_node->fields[tuple_get_item_node->index];
+      OnDeviceProps props = GetOnDeviceProps(post_tuple_get_item_node->tuple);
+      if (props.body.defined()) {
+        // on_device((x, y, z), device_type=D).1 ==> on_device(y, device_type=D)
+        return MaybeOnDevice(result, props.device_type, props.is_fixed);
+      } else {
+        return result;
+      }
+    }
+    return std::move(post_tuple_get_item);
+  }
 
   // Convert value to expression.
   Expr ObjectToExpr(const ObjectRef& value) {
@@ -224,35 +238,53 @@ class ConstantFolder : public MixedModeMutator {
       return Tuple(fields);
     } else {
       LOG(FATAL) << "Cannot handle " << value->GetTypeKey();
-      return Expr();
+      return {};
     }
   }
+
   // Constant evaluate an expression.
-  Expr ConstEvaluate(Expr expr) {
+  Expr ConstEvaluate(const Expr& expr) {
+    VLOG_CONTEXT << "ConstEvaluate";
+    VLOG(1) << "Evaluating :" << std::endl << PrettyPrint(expr);
+
+    // We'll invoke the interpreter using the generic CPU device and target. Technically there's
+    // no guarantee the results will be bitwise equal to what we'd get on the true device, however
+    // to support cross-compilation we don't want to assume the true device is available.
     Device dev;
     dev.device_type = kDLCPU;
     dev.device_id = 0;
     Target target = Target("llvm");
 
-    // use a fresh build context in case we are already in a build context.
+    // Use a fresh build context in case we are already in a build context.
     // needed for both execution and creation(due to JIT)
     With<transform::PassContext> fresh_build_ctx(transform::PassContext::Create());
 
-    return ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), dev, target));
+    Expr result =
+        ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), dev, target));
+    VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result);
+    return result;
   }
 
-  // Evaluate a call to the shape_of operator for tensors with constant
-  // shapes.
-  Expr EvaluateShapeOf(Expr expr, Array<Expr> args, Attrs attrs) {
-    Expr input = args[0];
-    const auto* param = attrs.as<ShapeOfAttrs>();
+  /*!
+   * \brief Returns the constant shape result of \p call if it is of the form \p shape_of(e) and
+   * \p e has a non-dynamic tensor shape. Returns null otherwise.
+   */
+  Optional<Expr> EvaluateShapeOf(const Call& call) {
+    if (call->op != shape_of_op_ && call->op != vm_shape_of_op_) {
+      return {};
+    }
+
+    VLOG(1) << "Evaluating for shape_of:" << std::endl << PrettyPrint(call);
+    ICHECK_EQ(call->args.size(), 1);
+    const auto* param = call->attrs.as<ShapeOfAttrs>();
     ICHECK(param != nullptr);
+    Expr input = call->args[0];
 
     tvm::Array<IndexExpr> ishape;
-    if (auto opt = GetConstantShape(input)) {
-      ishape = opt.value();
+    if (Optional<tvm::Array<IndexExpr>> opt_shape = GetConstantShape(input)) {
+      ishape = opt_shape.value();
     } else {
-      return expr;
+      return {};
     }
 
     // Get the constant shape
@@ -261,26 +293,26 @@ class ConstantFolder : public MixedModeMutator {
     dev.device_id = 0;
     runtime::NDArray value;
     DLDataType cdtype = DataType::Int(32);
-    if (ishape.size() == 0) {
+    if (ishape.empty()) {
       value = runtime::NDArray::Empty({}, cdtype, dev);
     } else {
       ICHECK_NE(ishape.size(), 0);
       std::vector<int64_t> cshape = {static_cast<int64_t>(ishape.size())};
       value = runtime::NDArray::Empty(cshape, cdtype, dev);
-      int32_t* dims = static_cast<int32_t*>(value->data);
+      auto* dims = static_cast<int32_t*>(value->data);
       using ::tvm::tir::IntImmNode;
       for (size_t i = 0; i < ishape.size(); ++i) {
-        if (const IntImmNode* dim = ishape[i].as<IntImmNode>()) {
+        if (const auto* dim = ishape[i].as<IntImmNode>()) {
           dims[i] = dim->value;
         } else {
-          return expr;
+          return {};
         }
       }
     }
 
     Constant shape = Downcast<Constant>(ObjectToExpr(value));
 
-    if (shape->data.Shape().size() == 0 && GetScalarFromConstant<int32_t>(shape) == 0) {
+    if (shape->data.Shape().empty() && GetScalarFromConstant<int32_t>(shape) == 0) {
       auto ndarray = runtime::NDArray::Empty({}, cdtype, dev);
       shape = Constant(ndarray);
     }
@@ -288,18 +320,25 @@ class ConstantFolder : public MixedModeMutator {
     return CastValue(shape, param->dtype);
   }
 
-  // Evaluate a call to the ndarray_size operator for tensors with constant
-  // shapes.
-  Expr EvaluateNdarraySize(Expr expr, Array<Expr> args, Attrs attrs) {
-    Expr input = args[0];
-    const auto* param = attrs.as<NdarraySizeAttrs>();
+  /*!
+   * \brief Returns the constant NDArray size of result of \p call if it is of the form
+   * \p ndarray_size(e) and \p e has non-dynamic tensor type. Returns null otherwise.
+   */
+  Optional<Expr> EvaluateNdarraySize(const Call& call) {
+    if (call->op != ndarray_size_op_) {
+      return {};
+    }
+    VLOG(1) << "Evaluating for ndarray_size:" << std::endl << PrettyPrint(call);
+    ICHECK_EQ(call->args.size(), 1);
+    Expr input = call->args[0];
+    const auto* param = call->attrs.as<NdarraySizeAttrs>();
     ICHECK(param != nullptr);
 
     tvm::Array<IndexExpr> ishape;
-    if (auto opt = GetConstantShape(input)) {
-      ishape = opt.value();
+    if (Optional<tvm::Array<IndexExpr>> opt_shape = GetConstantShape(input)) {
+      ishape = opt_shape.value();
     } else {
-      return expr;
+      return {};
     }
 
     // Get the constant size
@@ -309,17 +348,17 @@ class ConstantFolder : public MixedModeMutator {
     runtime::NDArray value;
     DLDataType cdtype = DataType::Int(32);
     value = runtime::NDArray::Empty({}, cdtype, dev);
-    int32_t* data = static_cast<int32_t*>(value->data);
-    if (ishape.size() == 0) {
+    auto* data = static_cast<int32_t*>(value->data);
+    if (ishape.empty()) {
       *data = 0;
     } else {
       *data = 1;
       using ::tvm::tir::IntImmNode;
       for (size_t i = 0; i < ishape.size(); ++i) {
-        if (const IntImmNode* dim = ishape[i].as<IntImmNode>()) {
+        if (const auto* dim = ishape[i].as<IntImmNode>()) {
           *data *= dim->value;
         } else {
-          return expr;
+          return {};
         }
       }
     }
@@ -337,31 +376,57 @@ class ConstantFolder : public MixedModeMutator {
   }
 
   Optional<tvm::Array<IndexExpr>> GetConstantShape(const Expr& input) {
-    tvm::Array<IndexExpr> ishape;
-    if (const ConstantNode* op = input.as<ConstantNode>()) {
-      ishape = op->tensor_type()->shape;
+    if (const auto* const_node = AsIgnoringOnDevice<ConstantNode>(input)) {
+      // TODO(mbs): This is not necessary since we only ever ask for the shapes for
+      // pre-rewritten expressions which will always have a checked_type.
+      return const_node->tensor_type()->shape;
     } else if (input->checked_type_.defined()) {
-      ishape = input->checked_type().as<TensorTypeNode>()->shape;
+      return input->checked_type().as<TensorTypeNode>()->shape;
     } else {
-      return Optional<tvm::Array<IndexExpr>>(nullptr);
+      return {};
     }
-
-    return Optional<tvm::Array<IndexExpr>>(ishape);
   }
+
+  // Module
+  IRModule module_;
+
+  // Cache the following ops for equivalence checking in this pass.
+  const Op& device_copy_op_;
+  const Op& shape_of_op_;
+  const Op& vm_shape_of_op_;
+  const Op& cast_op_;
+  const Op& ndarray_size_op_;
+
+  // True if currently within a "primitive" Relay Function.
+  bool inside_primitive_ = false;
 };
 
-Expr FoldConstant(const Expr& expr, const IRModule& mod) {
-  return ConstantFolder(mod).Mutate(expr);
-}
+}  // namespace
 
-TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr").set_body_typed(FoldConstant);
+TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(IsComplexConstant);
 
-namespace transform {
+/*!
+ * \brief Returns \p expr with any constant expressions evaluated and let-bound constants
+ * inlined. Returns \p expr unchanged if no change.
+ *
+ * CAUTION: The importers rely on this function returning \p expr unchanged to preserve sharing
+ * from their p.o.v. Furthermore, this function can be called before conversion to ANF so
+ * we must avoid all recursion.
+ */
+Expr FoldConstantExpr(const Expr& expr, const IRModule& mod) {
+  VLOG_CONTEXT << "FoldConstantExpr";
+  VLOG(1) << "folding:" << std::endl << PrettyPrint(expr);
+  Expr result = ConstantFolder(mod).VisitExpr(expr);
+  VLOG(1) << "folded to:" << std::endl << PrettyPrint(result);
+  return result;
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr").set_body_typed(FoldConstantExpr);
 
 Pass FoldConstant() {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(FoldConstant(f, m));
+        return Downcast<Function>(FoldConstantExpr(f, m));
       };
   return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
 }
@@ -369,6 +434,5 @@ Pass FoldConstant() {
 TVM_REGISTER_GLOBAL("relay._transform.FoldConstant").set_body_typed(FoldConstant);
 
 }  // namespace transform
-
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc
index 71917c3..dd582de 100644
--- a/src/relay/transforms/memory_alloc.cc
+++ b/src/relay/transforms/memory_alloc.cc
@@ -415,7 +415,6 @@ Pass ManifestAlloc(Target target_host, Map<tvm::Integer, tvm::Target> targets) {
   CheckAndUpdateHostConsistency(&targets, &target_host);
   return tvm::transform::CreateModulePass(
       [=](IRModule mod, const PassContext& pass_ctx) {
-        DLOG(INFO) << "tvm::relay::transform::ManifestAlloc";
         // We need to mutate module, therefore making a copy of it.
         mod.CopyOnWrite();
         mod->ImportFromStd("core.rly");
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 33a87c9..b4d7b41 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -112,14 +112,14 @@ class CUDADeviceAPI final : public DeviceAPI {
     ICHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes";
     void* ret;
     if (dev.device_type == kDLCUDAHost) {
-      DLOG(INFO) << "allocating " << nbytes << "bytes on host";
+      VLOG(1) << "allocating " << nbytes << "bytes on host";
       CUDA_CALL(cudaMallocHost(&ret, nbytes));
     } else {
       CUDA_CALL(cudaSetDevice(dev.device_id));
       size_t free_mem, total_mem;
       CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem));
-      DLOG(INFO) << "allocating " << nbytes << " bytes on device, with " << free_mem
-                 << " bytes currently free out of " << total_mem << " bytes available";
+      VLOG(1) << "allocating " << nbytes << " bytes on device, with " << free_mem
+              << " bytes currently free out of " << total_mem << " bytes available";
       CUDA_CALL(cudaMalloc(&ret, nbytes));
     }
     return ret;
@@ -127,11 +127,11 @@ class CUDADeviceAPI final : public DeviceAPI {
 
   void FreeDataSpace(Device dev, void* ptr) final {
     if (dev.device_type == kDLCUDAHost) {
-      DLOG(INFO) << "freeing host memory";
+      VLOG(1) << "freeing host memory";
       CUDA_CALL(cudaFreeHost(ptr));
     } else {
       CUDA_CALL(cudaSetDevice(dev.device_id));
-      DLOG(INFO) << "freeing device memory";
+      VLOG(1) << "freeing device memory";
       CUDA_CALL(cudaFree(ptr));
     }
   }
diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc
index a5e7d25..15d2aa0 100644
--- a/src/runtime/vm/executable.cc
+++ b/src/runtime/vm/executable.cc
@@ -308,7 +308,7 @@ void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) {
 VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
   std::vector<Index> fields;
   // Save the opcode.
-  DLOG(INFO) << "Serializing: " << instr << std::endl;
+  VLOG(1) << "Serializing: " << instr << std::endl;
   switch (instr.op) {
     case Opcode::Move: {
       // Number of fields = 2
diff --git a/src/runtime/vm/memory_manager.cc b/src/runtime/vm/memory_manager.cc
index 410f6c2..22afcce 100644
--- a/src/runtime/vm/memory_manager.cc
+++ b/src/runtime/vm/memory_manager.cc
@@ -119,14 +119,14 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) {
     std::unique_ptr<Allocator> alloc;
     switch (type) {
       case kNaive: {
-        DLOG(INFO) << "New naive allocator for " << DeviceName(dev.device_type) << "("
-                   << dev.device_id << ")";
+        VLOG(1) << "New naive allocator for " << DeviceName(dev.device_type) << "(" << dev.device_id
+                << ")";
         alloc.reset(new NaiveAllocator(dev));
         break;
       }
       case kPooled: {
-        DLOG(INFO) << "New pooled allocator for " << DeviceName(dev.device_type) << "("
-                   << dev.device_id << ")";
+        VLOG(1) << "New pooled allocator for " << DeviceName(dev.device_type) << "("
+                << dev.device_id << ")";
         alloc.reset(new PooledAllocator(dev));
         break;
       }
diff --git a/src/runtime/vm/pooled_allocator.h b/src/runtime/vm/pooled_allocator.h
index c282eb0..e5f2369 100644
--- a/src/runtime/vm/pooled_allocator.h
+++ b/src/runtime/vm/pooled_allocator.h
@@ -67,7 +67,7 @@ class PooledAllocator final : public Allocator {
     }
 
     used_memory_.fetch_add(size, std::memory_order_relaxed);
-    DLOG(INFO) << "allocate " << size << " B, used memory " << used_memory_ << " B";
+    VLOG(1) << "allocate " << size << " B, used memory " << used_memory_ << " B";
     return buf;
   }
 
@@ -77,7 +77,7 @@ class PooledAllocator final : public Allocator {
       memory_pool_.emplace(buffer.size, std::vector<Buffer>{});
     }
     memory_pool_.at(buffer.size).push_back(buffer);
-    DLOG(INFO) << "reclaim buffer " << buffer.size;
+    VLOG(1) << "reclaim buffer " << buffer.size;
   }
 
   size_t UsedMemory() const override { return used_memory_.load(std::memory_order_relaxed); }
@@ -93,7 +93,7 @@ class PooledAllocator final : public Allocator {
     }
     memory_pool_.clear();
     used_memory_ = 0;
-    DLOG(INFO) << "release all buffers";
+    VLOG(1) << "release all buffers";
   }
 
  private:
diff --git a/src/runtime/vm/serialize_utils.h b/src/runtime/vm/serialize_utils.h
index cbcdb1b..b4a1080 100644
--- a/src/runtime/vm/serialize_utils.h
+++ b/src/runtime/vm/serialize_utils.h
@@ -59,13 +59,13 @@ struct VMFunctionSerializer {
   /*! \brief The parameters of the VMFunction. */
   std::vector<std::string> params;
   /*! \brief The device type of each parameter of the VMFunction. */
-  std::vector<DLDeviceType> params_device_type;
+  std::vector<Index> params_device_type;
 
   VMFunctionSerializer() = default;
 
   VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions,
                        const std::vector<std::string>& params,
-                       const std::vector<DLDeviceType>& params_device_type)
+                       const std::vector<Index>& params_device_type)
       : name(name),
         register_file_size(register_file_size),
         num_instructions(num_instructions),
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index addd5ca..b903f79 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -236,7 +236,7 @@ void VirtualMachine::SetInput(std::string func_name, TVMArgs args, int offset) {
       << "The number of provided parameters doesn't match the number of assigned devices";
   std::vector<ObjectRef> func_args(param_names.size());
   for (int i = offset; i < args.size(); ++i) {
-    DLDeviceType device_type = vm_func.params_device_type[i - offset];
+    Index device_type = vm_func.params_device_type[i - offset];
     Device dev = GetDevice(device_type);
 
     if (args[i].type_code() == kTVMDLTensorHandle) {
@@ -284,20 +284,20 @@ Index VirtualMachine::PopFrame() {
 }
 
 void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args) {
-  DLOG(INFO) << "Invoking global " << func.name << " " << args.size();
+  VLOG(2) << "Invoking global " << func.name << " " << args.size();
 
   PushFrame(func.params.size(), this->pc_ + 1, func);
   for (size_t i = 0; i < args.size(); ++i) {
     WriteRegister(i, args[i]);
   }
-  DLOG(INFO) << "func.params= " << func.params.size();
+  VLOG(2) << "func.params= " << func.params.size();
 
   code_ = func.instructions.data();
   pc_ = 0;
 }
 
 ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<ObjectRef>& args) {
-  DLOG(INFO) << "Executing Function: " << std::endl << func;
+  VLOG(2) << "Executing Function: " << std::endl << func;
 
   InvokeGlobal(func, args);
   RunLoop();
@@ -309,7 +309,7 @@ ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<Obje
   auto it = exec_->global_map.find(name);
   ICHECK(it != exec_->global_map.end()) << "Cannot find function " << name << " in the executable";
   auto func_index_ = it->second;
-  DLOG(INFO) << "Invoke Global " << name << " at index " << func_index_;
+  VLOG(2) << "Invoke Global " << name << " at index " << func_index_;
   return Invoke(exec_->functions[func_index_], args);
 }
 
@@ -445,7 +445,7 @@ void VirtualMachine::RunLoop() {
   while (true) {
   main_loop:
     auto const& instr = code_[this->pc_];
-    DLOG(INFO) << "Executing(" << pc_ << "): " << instr;
+    VLOG(2) << "Executing(" << pc_ << "): " << instr;
 
     switch (instr.op) {
       case Opcode::Move: {
@@ -500,13 +500,13 @@ void VirtualMachine::RunLoop() {
         goto main_loop;
       }
       case Opcode::InvokePacked: {
-        DLOG(INFO) << "InvokedPacked " << instr.packed_index << " arity=" << instr.arity;
+        VLOG(2) << "InvokedPacked " << instr.packed_index << " arity=" << instr.arity;
         ICHECK_LE(instr.packed_index, packed_funcs_.size());
         const auto& func = packed_funcs_[instr.packed_index];
         const auto& arity = instr.arity;
         std::vector<ObjectRef> args;
         for (Index i = 0; i < arity; ++i) {
-          DLOG(INFO) << "arg" << i << " $" << instr.packed_args[i];
+          VLOG(2) << "arg" << i << " $" << instr.packed_args[i];
           auto arg = ReadRegister(instr.packed_args[i]);
           args.push_back(arg);
         }
@@ -579,6 +579,18 @@ void VirtualMachine::RunLoop() {
         auto storage_obj = ReadRegister(instr.alloc_tensor.storage);
         auto offset = LoadScalarInt(instr.alloc_tensor.offset);
         auto storage = Downcast<Storage>(storage_obj);
+#if TVM_LOG_DEBUG
+        std::ostringstream os;
+        os << "AllocTensor: ";
+        os << "offset=" << offset;
+        os << ", shape=[";
+        for (auto i : shape) {
+          os << i << ",";
+        }
+        os << "]";
+        os << ", dtype=" << DLDataType2String(instr.alloc_tensor.dtype);
+        VLOG(2) << os.str();
+#endif
         auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor.dtype);
 
         WriteRegister(instr.dst, obj);
@@ -625,17 +637,15 @@ void VirtualMachine::RunLoop() {
         OpStartHook(instr);
         auto size = LoadScalarInt(instr.alloc_storage.allocation_size);
         auto alignment = instr.alloc_storage.alignment;
-
-        DLOG(INFO) << "AllocStorage: allocation_size=" << size << ", alignment=" << alignment
-                   << ", dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint)
-                   << ", device_type=" << instr.alloc_storage.device_type;
-
         auto storage_obj = SimpleObjAllocator().make_object<StorageObj>();
         auto dev_type = instr.alloc_storage.device_type;
         ICHECK_LT(static_cast<size_t>(dev_type), allocators_.size())
             << "Memory allocator for device " << dev_type << " has not been initialized";
         auto* alloc = allocators_[dev_type];
         ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?";
+        VLOG(2) << "AllocStorage: allocation_size=" << size << ", alignment=" << alignment
+                << ", dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint)
+                << ", device_type=" << instr.alloc_storage.device_type;
         storage_obj->buffer = alloc->Alloc(size, alignment, instr.alloc_storage.dtype_hint);
         Storage storage(storage_obj);
         WriteRegister(instr.dst, storage);
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index df4234e..0ee5ce3 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -1526,4 +1526,7 @@ def test_empty_subgraph(run_module):
 
 
 if __name__ == "__main__":
-    pytest.main([__file__])
+    import sys
+
+    # sys.exit(pytest.main([__file__] + sys.argv[1:]))
+    test_maskrcnn_resnet50(run_module)
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index 1d488cc..147c420 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -523,3 +523,9 @@ def test_compile_check_configs_composite_target(mock_pkg, mock_pc, mock_fe, mock
         config={"relay.ext.mock.options": {"testopt": "value"}},
         disabled_pass=None,
     )
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index 7b4eb52..3a5f458 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -22,6 +22,16 @@ from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.testing import run_infer_type, create_workload
 
 
+def annot_func(f):
+    """Returns f with arg/result device attributes for the argument and result."""
+    return relay.op.annotation.function_on_device(f, [tvm.cpu()], tvm.cpu())
+
+
+def annot_expr(e):
+    """Returns e wrapped with an on_device annotation."""
+    return relay.op.annotation.on_device(e, tvm.cpu(), is_fixed=True)
+
+
 def run_opt_pass(expr, opt_pass):
     assert isinstance(opt_pass, tvm.transform.Pass)
 
@@ -75,7 +85,35 @@ def test_fold_const():
     with tvm.target.Target("cuda"):
         zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(zz, zexpected)
+    tvm.ir.assert_structural_equal(zz, zexpected)
+
+
+def test_fold_const_with_on_device():
+    """Make sure on_device annotations don't get in the way of constant folding"""
+    c_data = np.array([1, 2, 3]).astype("float32")
+    t = relay.TensorType([1, 2, 3], "float32")
+
+    def before():
+        c = relay.const(c_data)
+        x = relay.var("x", t)
+        y = relay.add(c, c)
+        y = relay.multiply(y, relay.const(2, "float32"))
+        y = relay.add(x, y)
+        z = relay.add(y, c)
+        f = relay.Function([x], z)
+        return annot_func(f)
+
+    def expected():
+        x = relay.var("x", t)
+        c_folded = (c_data + c_data) * 2
+        y = relay.add(x, relay.const(c_folded))
+        z = relay.add(y, relay.const(c_data))
+        f = relay.Function([x], z)
+        return annot_func(f)
+
+    zz = run_opt_pass(before(), transform.FoldConstant())
+    zexpected = run_opt_pass(expected(), transform.InferType())
+    tvm.ir.assert_structural_equal(zz, zexpected)
 
 
 def test_fold_let():
@@ -101,7 +139,37 @@ def test_fold_let():
 
     zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(zz, zexpected)
+    tvm.ir.assert_structural_equal(zz, zexpected)
+
+
+def test_fold_let_with_on_device():
+    """Make sure on_device annotations don't get in the way of constant folding,
+    and inlined constants bring their annotations with them."""
+    c_data = np.array(1).astype("float32")
+    t = relay.TensorType([1], "float32")
+
+    def before():
+        sb = relay.ScopeBuilder()
+        x = relay.var("x", t)
+        t1 = sb.let("t1", annot_expr(relay.const(c_data)))
+        t2 = sb.let("t2", annot_expr(relay.add(t1, t1)))
+        t3 = sb.let("t3", annot_expr(relay.add(t2, x)))
+        sb.ret(t3)
+        f = relay.Function([x], sb.get())
+        return annot_func(f)
+
+    def expected():
+        sb = relay.ScopeBuilder()
+        x = relay.var("x", t)
+        c_folded = c_data + c_data
+        t3 = sb.let("t3", annot_expr(relay.add(annot_expr(relay.const(c_folded)), x)))
+        sb.ret(t3)
+        f = relay.Function([x], sb.get())
+        return annot_func(f)
+
+    zz = run_opt_pass(before(), transform.FoldConstant())
+    zexpected = run_opt_pass(expected(), transform.InferType())
+    tvm.ir.assert_structural_equal(zz, zexpected)
 
 
 def test_fold_tuple():
@@ -124,7 +192,7 @@ def test_fold_tuple():
 
     zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(zz, zexpected)
+    tvm.ir.assert_structural_equal(zz, zexpected)
 
 
 def test_fold_concat():
@@ -143,7 +211,7 @@ def test_fold_concat():
 
     zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(zz, zexpected)
+    tvm.ir.assert_structural_equal(zz, zexpected)
 
 
 def test_fold_if():
@@ -164,7 +232,7 @@ def test_fold_if():
 
     zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(zz, zexpected)
+    tvm.ir.assert_structural_equal(zz, zexpected)
 
     cond_data = np.array(0).astype("bool")
 
@@ -182,7 +250,7 @@ def test_fold_if():
 
     zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(zz, zexpected)
+    tvm.ir.assert_structural_equal(zz, zexpected)
 
 
 def test_fold_shape_of():
@@ -204,7 +272,7 @@ def test_fold_shape_of():
     for dtype in ["int32", "float32"]:
         zz = run_opt_pass(before(dtype), transform.FoldConstant())
         zexpected = run_opt_pass(expected(dtype), transform.InferType())
-        assert tvm.ir.structural_equal(zz, zexpected)
+        tvm.ir.assert_structural_equal(zz, zexpected)
 
 
 def test_fold_ndarray_size():
@@ -227,7 +295,7 @@ def test_fold_ndarray_size():
     for dtype in ["int32", "float32"]:
         zz = run_opt_pass(before(dtype), transform.FoldConstant())
         zexpected = run_opt_pass(expected(dtype), transform.InferType())
-        assert tvm.ir.structural_equal(zz, zexpected)
+        tvm.ir.assert_structural_equal(zz, zexpected)
 
 
 def test_fold_batch_norm():
@@ -272,7 +340,7 @@ def test_fold_batch_norm():
         mod = remove_bn_pass(mod)
 
     expect = run_infer_type(expected())
-    assert tvm.ir.structural_equal(mod["main"], expect)
+    tvm.ir.assert_structural_equal(mod["main"], expect)
 
 
 def test_fold_dropout():
@@ -295,15 +363,11 @@ def test_fold_dropout():
     with tvm.transform.PassContext(opt_level=3):
         after_mod = passes(before_mod)
 
-    assert tvm.ir.structural_equal(run_infer_type(before_mod["main"]), after_mod["main"])
+    tvm.ir.assert_structural_equal(run_infer_type(before_mod["main"]), after_mod["main"])
 
 
 if __name__ == "__main__":
-    test_fold_const()
-    test_fold_let()
-    test_fold_tuple()
-    test_fold_concat()
-    test_fold_shape_of()
-    test_fold_batch_norm()
-    test_fold_ndarray_size()
-    test_fold_dropout()
+    import sys
+    import pytest
+
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/relay/test_prng.py b/tests/python/relay/test_prng.py
index 79ed014..29e271b 100644
--- a/tests/python/relay/test_prng.py
+++ b/tests/python/relay/test_prng.py
@@ -166,7 +166,6 @@ def test_threefry_generate_out_size():
 
 
 if __name__ == "__main__":
-    test_threefry_repeatability(tvm.target.Target("llvm"), tvm.device("cpu"))
-    test_threefry_split(tvm.target.Target("llvm"), tvm.device("cpu"))
-    test_threefry_sequential_generate(tvm.target.Target("llvm"), tvm.device("cpu"))
-    test_threefry_sequential_generate_remaining(tvm.target.Target("llvm"), tvm.device("cpu"))
+    import sys
+
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py
index 1edc5d3..7b708cb 100644
--- a/tests/python/unittest/test_target_codegen_vulkan.py
+++ b/tests/python/unittest/test_target_codegen_vulkan.py
@@ -17,7 +17,6 @@
 
 import random
 import re
-import sys
 import threading
 
 import numpy as np
@@ -557,4 +556,6 @@ def test_shared_mem_alloc(target, dev):
 
 
 if __name__ == "__main__":
-    sys.exit(pytest.main(sys.argv))
+    import sys
+
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))