Posted to commits@tvm.apache.org by ma...@apache.org on 2020/09/09 19:09:36 UTC

[incubator-tvm] branch master updated: [RELAY][REFACTOR] Mix mode context analysis (#6403)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 50adbfa  [RELAY][REFACTOR] Mix mode context analysis (#6403)
50adbfa is described below

commit 50adbfac6a533f231a930bf80617c8eb3d7097a7
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Wed Sep 9 12:09:22 2020 -0700

    [RELAY][REFACTOR] Mix mode context analysis (#6403)
    
    * mix mode context analysis
    
    * add uses_gpu decorator for more tests
    
    * revert visit counter
    
    * relax visit limit
    
    * lint
    
    * bump visit limit to 19
    
    * typo
---
 src/relay/analysis/context_analysis.cc             |  70 ++++++-------
 tests/python/frontend/pytorch/test_forward.py      |   4 +-
 .../pytorch/{lstm_test.py => test_lstm.py}         |  10 +-
 tests/python/relay/test_adt.py                     | 109 +++++++++++---------
 tests/python/relay/test_any.py                     |  75 +++++++-------
 tests/python/relay/test_vm.py                      | 112 +++++++++++++--------
 6 files changed, 212 insertions(+), 168 deletions(-)

diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc
index bbea039..5fbd8a4 100644
--- a/src/relay/analysis/context_analysis.cc
+++ b/src/relay/analysis/context_analysis.cc
@@ -160,11 +160,14 @@ DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
  * \brief Compute on which device each sub-expression will execute. A union find
  * algorithm is used to assign and merge the context domains.
  */
-class ContextAnalyzer : public ExprVisitor {
+class ContextAnalyzer : public MixedModeVisitor {
  public:
   ContextAnalyzer(const IRModule& mod, const GlobalVar& current_func,
                   const TVMContext& default_context)
-      : mod_(mod), current_func_(current_func), default_context_(default_context) {
+      : MixedModeVisitor(9),  // the number of repeated visits a node can perform
+        mod_(mod),
+        current_func_(current_func),
+        default_context_(default_context) {
     cpu_ctx_.device_type = kDLCPU;
     cpu_ctx_.device_id = 0;
   }
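
The doc comment above describes the analysis as a union-find over device domains, and the new MixedModeVisitor(9) argument bounds how many times a node may be revisited. For orientation, here is a minimal Python sketch of the union-find idea, assuming illustrative names (Domain, find, unify, "cpu"/"gpu") rather than TVM's actual C++ types:

    # Minimal union-find sketch of the device-domain idea described above.
    # Domain, find, unify and the device labels are illustrative, not TVM types.
    class Domain:
        """A device domain: unconstrained (device=None) or pinned to a device."""
        def __init__(self, device=None):
            self.device = device
            self.parent = self  # union-find parent pointer

    def find(d):
        """Return the representative of d's set, halving paths as we go."""
        while d.parent is not d:
            d.parent = d.parent.parent
            d = d.parent
        return d

    def unify(lhs, rhs):
        """Merge two domains; a pinned domain wins over an unconstrained one."""
        lhs, rhs = find(lhs), find(rhs)
        if lhs is rhs:
            return lhs
        if lhs.device and rhs.device and lhs.device != rhs.device:
            raise ValueError("cannot unify %s with %s" % (lhs.device, rhs.device))
        rep, other = (lhs, rhs) if lhs.device is not None else (rhs, lhs)
        other.parent = rep
        return rep

    # Example: a shape argument is pinned to CPU while the tensor it describes
    # unifies with whatever device the surrounding call runs on.
    shape_dom, data_dom, call_dom = Domain("cpu"), Domain(), Domain("gpu")
    unify(data_dom, call_dom)
    assert find(data_dom).device == "gpu"
    assert find(shape_dom).device == "cpu"
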
@@ -295,7 +298,7 @@ class ContextAnalyzer : public ExprVisitor {
       UnifyVarCall(cn);
     } else {
       UnifyCall(call, cn->args, {call}, Bottom());
-      ExprVisitor::VisitExpr_(cn);
+      MixedModeVisitor::VisitExpr_(cn);
     }
   }
 
@@ -315,20 +318,18 @@ class ContextAnalyzer : public ExprVisitor {
       // Unify let var, value, and body
       Unify(DeviceFor(let->var), DeviceFor(let->value));
       UnifyExpr(let, let->body);
-      ExprVisitor::VisitExpr(let->value);
+      MixedModeVisitor::VisitExpr(let->value);
       expr = let->body;
     }
     // Visit the last body
-    ExprVisitor::VisitExpr(expr);
+    MixedModeVisitor::VisitExpr(expr);
   }
 
   void VisitExpr_(const FunctionNode* fn) final {
     auto func = GetRef<Function>(fn);
-    auto it = visited_.find(func);
     // No need to step into fused primitive functions as they are handled as
     // a whole.
-    if (fn->HasNonzeroAttr(attr::kPrimitive) ||
-        (it != visited_.end() && !DeviceFor(func)->IsEmptyDomain())) {
+    if (fn->HasNonzeroAttr(attr::kPrimitive)) {
       return;
     }
 
@@ -336,8 +337,7 @@ class ContextAnalyzer : public ExprVisitor {
     for (const auto& it : fn->params) {
       DeviceFor(it);
     }
-    ExprVisitor::VisitExpr(fn->body);
-    visited_.insert(func);
+    MixedModeVisitor::VisitExpr(fn->body);
   }
 
   void VisitExpr_(const TupleNode* tn) final {
@@ -350,7 +350,7 @@ class ContextAnalyzer : public ExprVisitor {
       }
       Unify(device, DeviceFor(tup));
     }
-    ExprVisitor::VisitExpr_(tn);
+    MixedModeVisitor::VisitExpr_(tn);
   }
 
   void VisitExpr_(const TupleGetItemNode* tn) final {
@@ -358,7 +358,7 @@ class ContextAnalyzer : public ExprVisitor {
 
     Unify(DeviceFor(item), DeviceFor(item->tuple));
 
-    ExprVisitor::VisitExpr_(tn);
+    MixedModeVisitor::VisitExpr_(tn);
   }
 
   void VisitExpr_(const MatchNode* mn) final {
@@ -368,7 +368,11 @@ class ContextAnalyzer : public ExprVisitor {
     for (const auto& c : m->clauses) {
       device = Unify(device, DeviceFor(c->rhs));
     }
-    ExprVisitor::VisitExpr_(mn);
+    MixedModeVisitor::VisitLeaf(mn->data);
+    for (const Clause& c : mn->clauses) {
+      this->VisitClause(c);
+      MixedModeVisitor::VisitLeaf(c->rhs);
+    }
   }
 
   void VisitExpr_(const GlobalVarNode* gvn) final { DeviceFor(GetRef<GlobalVar>(gvn)); }
@@ -465,7 +469,7 @@ class ContextAnalyzer : public ExprVisitor {
     //  same device to the source device type of the device copy op.
     //  The call itself has the same device type to the destination.
     UnifyDeviceCopy(inps, outs, src_dev_type, dst_dev_type);
-    ExprVisitor::VisitExpr_(call);
+    MixedModeVisitor::VisitExpr_(call);
   }
 
   void UnifyAllocStorageCall(const CallNode* call) {
@@ -475,7 +479,7 @@ class ContextAnalyzer : public ExprVisitor {
     // The arguments of alloc storage should be on CPU.
     for (int i = 0; i < 2; i++) {
       Unify(DeviceFor(call->args[i]), DeviceType(cpu_ctx_));
-      ExprVisitor::VisitExpr(call->args[i]);
+      MixedModeVisitor::VisitExpr(call->args[i]);
     }
     TVMContext ctx;
     const auto* attrs = call->attrs.as<AllocStorageAttrs>();
@@ -494,7 +498,7 @@ class ContextAnalyzer : public ExprVisitor {
 
     // The shape for alloc_tensor should be on CPU.
     Unify(DeviceFor(shape), DeviceType(cpu_ctx_));
-    ExprVisitor::VisitExpr(shape);
+    MixedModeVisitor::VisitExpr(shape);
   }
 
   void UnifyShapeFuncCall(const CallNode* call) {
@@ -509,11 +513,11 @@ class ContextAnalyzer : public ExprVisitor {
     Tuple outputs = Downcast<Tuple>(call->args[2]);
     UnifyCall(GetRef<Call>(call), inps->fields, outputs->fields, shape_func_domain);
     for (const auto& it : inps->fields) {
-      ExprVisitor::VisitExpr(it);
+      MixedModeVisitor::VisitExpr(it);
     }
 
     for (const auto& it : outputs->fields) {
-      ExprVisitor::VisitExpr(it);
+      MixedModeVisitor::VisitExpr(it);
     }
   }
 
@@ -523,13 +527,13 @@ class ContextAnalyzer : public ExprVisitor {
     Tuple inps = Downcast<Tuple>(call->args[1]);
     Tuple outputs = Downcast<Tuple>(call->args[2]);
     UnifyCall(call->args[0], inps->fields, outputs->fields, Bottom());
-    ExprVisitor::VisitExpr_(call);
+    MixedModeVisitor::VisitExpr_(call);
   }
 
   void UnifyShapeOfCall(const CallNode* call) {
     // vm shape_of is always on the CPU.
     CHECK_EQ(call->args.size(), 1U);
-    ExprVisitor::VisitExpr(call->args[0]);
+    MixedModeVisitor::VisitExpr(call->args[0]);
     // Note we don't unify the input of a shape_of with the cpu domain. This is
     // because vm.shape_of has a native instruction to compute the shape of
     // a tensor regardless its device type.
@@ -547,8 +551,8 @@ class ContextAnalyzer : public ExprVisitor {
 
     // The shape field of reshape_tensor is always on the CPU.
     Unify(DeviceFor(shape), DeviceType(cpu_ctx_));
-    ExprVisitor::VisitExpr(data);
-    ExprVisitor::VisitExpr(shape);
+    MixedModeVisitor::VisitExpr(data);
+    MixedModeVisitor::VisitExpr(shape);
   }
 
   void UnifyFunctionCall(const CallNode* call) {
@@ -556,7 +560,7 @@ class ContextAnalyzer : public ExprVisitor {
     // Unify the arguments of the caller.
     for (const auto& arg : call->args) {
       device = Unify(device, DeviceFor(arg));
-      ExprVisitor::VisitExpr(arg);
+      MixedModeVisitor::VisitExpr(arg);
     }
 
     // Unify the parameters of the callee.
@@ -564,7 +568,7 @@ class ContextAnalyzer : public ExprVisitor {
     Function func = Downcast<Function>(call->op);
     for (const auto& param : func->params) {
       device = Unify(device, DeviceFor(param));
-      ExprVisitor::VisitExpr(param);
+      MixedModeVisitor::VisitExpr(param);
     }
 
     // Unify the function expression and its body
@@ -573,7 +577,7 @@ class ContextAnalyzer : public ExprVisitor {
 
     // Step into the callee. It will be skipped if the callee if a primitive
     // function
-    ExprVisitor::VisitExpr(call->op);
+    MixedModeVisitor::VisitExpr(call->op);
   }
 
   // Invoke a global function.
@@ -588,7 +592,7 @@ class ContextAnalyzer : public ExprVisitor {
     for (size_t i = 0; i < call->args.size(); i++) {
       Expr arg = call->args[i];
       Expr param = func->params[i];
-      ExprVisitor::VisitExpr(arg);
+      MixedModeVisitor::VisitExpr(arg);
 
       // Save the the arg to function mapping for closures as it will
       // be invoked/unified later.
@@ -616,8 +620,7 @@ class ContextAnalyzer : public ExprVisitor {
     auto cur_func = current_func_;
     current_func_ = gv;
     if (cur_func->name_hint != gv->name_hint) {
-      ExprVisitor::VisitExpr(func);
-      visited_.insert(func);
+      MixedModeVisitor::VisitExpr(func);
     }
     // Exit the frame.
     current_func_ = cur_func;
@@ -632,7 +635,7 @@ class ContextAnalyzer : public ExprVisitor {
     auto glb_var = it->second;
     CHECK(mod_.defined()) << "Cannot analyze context on a globalvar without module";
     Function func = Downcast<Function>(mod_->Lookup(glb_var));
-    // Unify the underlying function for clousre or currying funcitons.
+    // Unify the underlying function for closure or currying functions.
     while (IsClosure(func) || IsCurrying(func)) {
       device = Unify(device, DeviceFor(func));
       if (IsClosure(func)) {
@@ -641,14 +644,14 @@ class ContextAnalyzer : public ExprVisitor {
         Let let = Downcast<Let>(func->body);
         func = Downcast<Function>(mod_->Lookup(closures_[let->var]));
       } else {
-        LOG(FATAL) << "func is expected to be a closure or a currying funciton";
+        LOG(FATAL) << "func is expected to be a closure or a currying function";
       }
     }
 
     CHECK_EQ(call->args.size(), func->params.size());
     for (size_t i = 0; i < call->args.size(); i++) {
       Unify(DeviceFor(call->args[i]), DeviceFor(func->params[i]));
-      ExprVisitor::VisitExpr(call->args[i]);
+      MixedModeVisitor::VisitExpr(call->args[i]);
     }
     device = Unify(device, DeviceFor(call->op));
     device = Unify(device, DeviceFor(glb_var));
@@ -658,8 +661,7 @@ class ContextAnalyzer : public ExprVisitor {
     auto cur_func = current_func_;
     current_func_ = glb_var;
     if (cur_func->name_hint != glb_var->name_hint) {
-      ExprVisitor::VisitExpr(func);
-      visited_.insert(func);
+      MixedModeVisitor::VisitExpr(func);
     }
     current_func_ = cur_func;
   }
@@ -684,8 +686,6 @@ class ContextAnalyzer : public ExprVisitor {
    * will be invoked lazily.
    */
   std::unordered_map<Expr, GlobalVar, runtime::ObjectPtrHash, runtime::ObjectPtrEqual> closures_;
-  /* \brief Cache the visited functions. */
-  std::unordered_set<Expr, runtime::ObjectPtrHash, runtime::ObjectPtrEqual> visited_;
 };
 
 }  // namespace analysis
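
The hunks above replace ExprVisitor with MixedModeVisitor throughout the analyzer. In TVM, the mixed-mode visitors walk dataflow regions with an explicit worklist instead of pure recursion, so deeply nested expressions (for example long let chains) do not exhaust the call stack, and the dropped visited_ cache is replaced by the visitor's own per-node visit counting. The following is a generic Python sketch of the recursive-versus-iterative difference on a toy chain, not TVM's visitor code:

    import sys

    # A toy stand-in for a deeply nested expression chain (e.g. many nested lets).
    class Node:
        def __init__(self, child=None):
            self.child = child

    def build_chain(depth):
        node = None
        for _ in range(depth):
            node = Node(node)
        return node

    def visit_recursive(node):
        """Naive recursive visit: raises RecursionError on deep chains."""
        if node is None:
            return 0
        return 1 + visit_recursive(node.child)

    def visit_iterative(node):
        """Worklist-based visit: stack usage independent of chain depth."""
        count, stack = 0, [node]
        while stack:
            cur = stack.pop()
            if cur is None:
                continue
            count += 1
            stack.append(cur.child)
        return count

    chain = build_chain(10 * sys.getrecursionlimit())
    print(visit_iterative(chain))   # fine at any depth
    # visit_recursive(chain)        # would overflow Python's call stack
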
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index e651700..fe14c91 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3198,9 +3198,9 @@ if __name__ == "__main__":
     test_simple_rnn()
 
     # More complex recurrent models
-    from lstm_test import custom_lstm_test
+    from lstm_test import test_custom_lstm
 
-    custom_lstm_test()
+    test_custom_lstm()
 
     # Test bert model
     test_forward_pretrained_bert_base_uncased()
diff --git a/tests/python/frontend/pytorch/lstm_test.py b/tests/python/frontend/pytorch/test_lstm.py
similarity index 97%
rename from tests/python/frontend/pytorch/lstm_test.py
rename to tests/python/frontend/pytorch/test_lstm.py
index 4616698..4524a72 100644
--- a/tests/python/frontend/pytorch/lstm_test.py
+++ b/tests/python/frontend/pytorch/test_lstm.py
@@ -215,8 +215,8 @@ def assert_equal(tvm_result, torch_result):
                                     rtol=1e-4, atol=1e-4)
 
 
-def run_and_compare(mod, params, pt_result):
-    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
+def run_and_compare(mod, params, pt_result, target, ctx):
+    executor = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
     evaluator = executor.evaluate()
     exec_res = evaluator(**params)
 
@@ -258,7 +258,8 @@ def convert_list_to_vmobj(py_lst):
     return adt_lst
 
 
-def custom_lstm_test():
+@tvm.testing.uses_gpu
+def test_custom_lstm():
     input_name = "input"
     states_name = "states"
     seq_len = 5
@@ -332,4 +333,5 @@ def custom_lstm_test():
         else:
             params[states_name] = states_np
 
-        run_and_compare(mod, params, pt_result)
+        for tgt, ctx in tvm.testing.enabled_targets():
+            run_and_compare(mod, params, pt_result, target=tgt, ctx=ctx)
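
The test changes above follow a pattern applied throughout this commit: mark a test with @tvm.testing.uses_gpu and loop over tvm.testing.enabled_targets() so it exercises every device the build enables. A condensed sketch of that pattern is shown below; the test name and the element-wise-add workload are placeholders, not part of the commit:

    import numpy as np
    import tvm
    import tvm.testing
    from tvm import relay

    @tvm.testing.uses_gpu
    def test_vm_add():
        # Placeholder workload; the real tests build LSTM/ADT/dynamic-shape modules.
        x = relay.var("x", shape=(4,), dtype="float32")
        y = relay.var("y", shape=(4,), dtype="float32")
        mod = tvm.IRModule.from_expr(relay.Function([x, y], x + y))

        x_np = np.random.rand(4).astype("float32")
        y_np = np.random.rand(4).astype("float32")

        # Run the VM executor on every target enabled for this build (llvm, cuda, ...).
        for target, ctx in tvm.testing.enabled_targets():
            executor = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
            result = executor.evaluate()(x_np, y_np)
            tvm.testing.assert_allclose(result.asnumpy(), x_np + y_np)
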
diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py
index d0e0105..da72429 100644
--- a/tests/python/relay/test_adt.py
+++ b/tests/python/relay/test_adt.py
@@ -152,11 +152,13 @@ def get_scalar(tv):
     return tv.asnumpy().item()
 
 
+@tvm.testing.uses_gpu
 def test_nat_value():
     assert count(make_nat_value(p, 10)) == 10
     assert count(intrp.evaluate(s(s(z())))) == 2
 
 
+@tvm.testing.uses_gpu
 def test_nat_constructor():
     func = relay.Function([], z())
     test_z = relay.GlobalVar("test_z")
@@ -168,24 +170,29 @@ def test_nat_constructor():
     assert mod[test_sz].body.checked_type == nat()
 
 
+@tvm.testing.uses_gpu
 def test_double():
     assert mod[double].checked_type == relay.FuncType([nat()], nat())
     res = intrp.evaluate(double(s(z())))
     assert count(res) == 2
 
 
+@tvm.testing.uses_gpu
 def test_add():
     assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat())
     res = intrp.evaluate(add(s(z()), s(z())))
     assert count(res) == 2
 
 
+@tvm.testing.uses_gpu
 def test_list_constructor():
     test_consz = relay.GlobalVar("test_consz")
     func = relay.Function([], cons(z(), nil()))
     mod[test_consz] = func
     assert mod[test_consz].body.checked_type == l(nat())
 
+
+@tvm.testing.uses_gpu
 def test_hd_tl():
     expected = list(range(10))
     l = nil()
@@ -199,6 +206,8 @@ def test_hd_tl():
 
     assert got == expected
 
+
+@tvm.testing.uses_gpu
 def test_nth():
     expected = list(range(10))
     l = nil()
@@ -210,6 +219,7 @@ def test_nth():
         assert get_scalar(item) == i
 
 
+@tvm.testing.uses_gpu
 def test_update():
     expected = list(range(10))
     l = nil()
@@ -227,6 +237,8 @@ def test_update():
 
     assert got == expected
 
+
+@tvm.testing.uses_gpu
 def test_length():
     a = relay.TypeVar("a")
     assert mod[length].checked_type == relay.FuncType([l(a)], relay.scalar_type('int32'), [a])
@@ -234,6 +246,7 @@ def test_length():
     assert get_scalar(res) == 3
 
 
+@tvm.testing.uses_gpu
 def test_map():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
@@ -249,6 +262,7 @@ def test_map():
     assert count(ones[0]) == 1 and count(ones[1]) == 1
 
 
+@tvm.testing.uses_gpu
 def test_foldl():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
@@ -270,6 +284,7 @@ def test_foldl():
     assert count(reversed[4]) == 1 and count(reversed[5]) == 1
 
 
+@tvm.testing.uses_gpu
 def test_foldr():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
@@ -289,6 +304,7 @@ def test_foldr():
     assert count(same[0]) == 1 and count(same[1]) == 2 and count(same[2]) == 3
 
 
+@tvm.testing.uses_gpu
 def test_foldr1():
     a = relay.TypeVar("a")
     lhs = mod[p.foldr1].checked_type
@@ -306,12 +322,14 @@ def test_foldr1():
     assert count(res) == 6
 
 
+@tvm.testing.uses_gpu
 def test_sum():
     assert mod[sum].checked_type == relay.FuncType([l(relay.scalar_type('int32'))], relay.scalar_type('int32'))
     res = intrp.evaluate(sum(cons(relay.const(1), cons(relay.const(2), nil()))))
     assert get_scalar(res) == 3
 
 
+@tvm.testing.uses_gpu
 def test_concat():
     a = relay.TypeVar("a")
     assert mod[concat].checked_type == relay.FuncType([l(a), l(a)], l(a), [a])
@@ -328,6 +346,7 @@ def test_concat():
     assert count(catted[3]) == 4
 
 
+@tvm.testing.uses_gpu
 def test_filter():
     a = relay.TypeVar("a")
     expected_type = relay.FuncType([
@@ -362,6 +381,7 @@ def test_filter():
     assert count(filtered[1]) == 5
 
 
+@tvm.testing.uses_gpu
 def test_zip():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
@@ -403,6 +423,7 @@ def test_zip():
     assert len(to_list(singleton[0][1])) == 0
 
 
+@tvm.testing.uses_gpu
 def test_rev():
     a = relay.TypeVar("a")
     assert mod[rev].checked_type == relay.FuncType([l(a)], l(a), [a])
@@ -418,6 +439,7 @@ def test_rev():
     assert count(reversed[2]) == 1
 
 
+@tvm.testing.uses_gpu
 def test_unfoldr():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
@@ -445,6 +467,7 @@ def test_unfoldr():
     assert count(unfolded[2]) == 1
 
 
+@tvm.testing.uses_gpu
 def test_unfoldl():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
@@ -472,6 +495,7 @@ def test_unfoldl():
     assert count(unfolded[2]) == 3
 
 
+@tvm.testing.uses_gpu
 def test_map_accumr():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
@@ -501,6 +525,7 @@ def test_map_accumr():
     assert count(new_vals[2]) == 3
 
 
+@tvm.testing.uses_gpu
 def test_map_accuml():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
@@ -529,6 +554,7 @@ def test_map_accuml():
     assert count(new_vals[2]) == 1
 
 
+@tvm.testing.uses_gpu
 def test_optional_matching():
     x = relay.Var('x')
     y = relay.Var('y')
@@ -550,6 +576,7 @@ def test_optional_matching():
     assert count(reduced[1]) == 1
 
 
+@tvm.testing.uses_gpu
 def test_tmap():
     a = relay.TypeVar("a")
     b = relay.TypeVar("b")
@@ -573,6 +600,7 @@ def test_tmap():
         assert len(subtree['children']) == 0
 
 
+@tvm.testing.uses_gpu
 def test_size():
     a = relay.TypeVar("a")
     lhs = mod[size].checked_type
@@ -587,6 +615,7 @@ def test_size():
     assert get_scalar(res) == 10
 
 
+@tvm.testing.uses_gpu
 def test_wildcard_match_solo():
     x = relay.Var('x', nat())
     copy = relay.Function([x],
@@ -597,6 +626,7 @@ def test_wildcard_match_solo():
     assert count(res) == 3
 
 
+@tvm.testing.uses_gpu
 def test_wildcard_match_order():
     x = relay.Var('x', l(nat()))
     y = relay.Var('y')
@@ -618,6 +648,7 @@ def test_wildcard_match_order():
     assert count(res) == 0
 
 
+@tvm.testing.uses_gpu
 def test_nested_matches():
     a = relay.TypeVar('a')
     x = relay.Var('x')
@@ -659,6 +690,7 @@ def test_nested_matches():
         assert count(flat[i]) == i + 1
 
 
+@tvm.testing.uses_gpu
 def test_match_full_var():
     x = relay.Var('x')
     v = relay.Var('v')
@@ -679,6 +711,7 @@ def test_match_full_var():
     assert count(zeroes[1]) == 0
 
 
+@tvm.testing.uses_gpu
 def test_nested_pattern_match():
     x = relay.Var('x', l(nat()))
     h1 = relay.Var('h1')
@@ -705,6 +738,7 @@ def test_nested_pattern_match():
     assert count(res) == 2
 
 
+@tvm.testing.uses_gpu
 def test_compose():
     n = relay.Var('n')
     inc = relay.Function([n], s(n))
@@ -713,6 +747,7 @@ def test_compose():
     assert count(res) == 5
 
 
+@tvm.testing.uses_gpu
 def test_iterate():
     expr = relay.Call(iterate(double, relay.const(2)), [make_nat_expr(3)])
     res = intrp.evaluate(relay.Function([], expr)())
@@ -730,6 +765,7 @@ def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", rtol=1e-5):
             tvm.testing.assert_allclose(ref_res, got, rtol=rtol, atol=rtol)
 
 
+@tvm.testing.uses_gpu
 def test_tensor_expand_dims():
     def run(dtype):
         x = relay.var('x')
@@ -745,6 +781,7 @@ def test_tensor_expand_dims():
     run('int32')
 
 
+@tvm.testing.uses_gpu
 def test_tensor_array_constructor():
     def run(dtype):
         x = relay.var('x')
@@ -758,6 +795,7 @@ def test_tensor_array_constructor():
     run('int32')
 
 
+@tvm.testing.uses_gpu
 def test_tensor_array_read():
     def run(dtype):
         mod = tvm.IRModule()
@@ -774,6 +812,7 @@ def test_tensor_array_read():
     run('int32')
 
 
+@tvm.testing.uses_gpu
 def test_tensor_array_write():
     def run(dtype):
         mod = tvm.IRModule()
@@ -794,6 +833,7 @@ def test_tensor_array_write():
     run('int32')
 
 
+@tvm.testing.uses_gpu
 def test_tensor_array_stack():
     def run(dtype):
         mod = tvm.IRModule()
@@ -816,6 +856,7 @@ def test_tensor_array_stack():
     run('int32')
 
 
+@tvm.testing.uses_gpu
 def test_tensor_array_unstack():
     def run(dtype):
         mod = tvm.IRModule()
@@ -829,6 +870,7 @@ def test_tensor_array_unstack():
     run('int32')
 
 
+@tvm.testing.uses_gpu
 def test_tensor_take():
     def run(dtype):
         mod = tvm.IRModule()
@@ -848,6 +890,7 @@ def test_tensor_take():
     run('int32')
 
 
+@tvm.testing.uses_gpu
 def test_tensor_concatenate():
     def run(dtype):
         mod = tvm.IRModule()
@@ -866,6 +909,7 @@ def test_tensor_concatenate():
     run('int32')
 
 
+@tvm.testing.uses_gpu
 def test_tensor_array_concat():
     def run(dtype):
         mod = tvm.IRModule()
@@ -889,6 +933,7 @@ def test_tensor_array_concat():
     run('int32')
 
 
+@tvm.testing.uses_gpu
 def test_tensor_array_scatter():
     def run(dtype):
         mod = tvm.IRModule()
@@ -939,6 +984,7 @@ def test_tensor_array_scatter():
     run('int32')
 
 
+@tvm.testing.uses_gpu
 def test_tensor_array_split():
     def run(dtype):
         mod = tvm.IRModule()
@@ -982,6 +1028,8 @@ def test_tensor_array_split():
     run('float32')
     run('int32')
 
+
+@tvm.testing.uses_gpu
 def test_static_tensor_take():
     def run(dtype, shape):
         mod = tvm.IRModule()
@@ -1004,6 +1052,7 @@ def test_static_tensor_take():
     run('int32', [15, 11])
 
 
+@tvm.testing.uses_gpu
 def test_static_tensor_concatenate():
     def run(dtype, shape):
         mod = tvm.IRModule()
@@ -1025,6 +1074,7 @@ def test_static_tensor_concatenate():
     run('int32', [2, 3])
 
 
+@tvm.testing.uses_gpu
 def test_static_tensor_expand_dims():
     def run(dtype, shape):
         x = relay.var('x')
@@ -1043,6 +1093,7 @@ def test_static_tensor_expand_dims():
     run('int32', [2,])
 
 
+@tvm.testing.uses_gpu
 def test_static_tensor_array_constructor():
     def run(dtype, shape):
         mod = tvm.IRModule()
@@ -1054,6 +1105,7 @@ def test_static_tensor_array_constructor():
     run('float32', [1, 1])
 
 
+@tvm.testing.uses_gpu
 def test_static_tensor_array_read():
     def run(dtype, shape):
         mod = tvm.IRModule()
@@ -1093,6 +1145,7 @@ def test_static_tensor_array_read():
     run('int32', [2, 3])
 
 
+@tvm.testing.uses_gpu
 def test_static_tensor_array_write():
     def run(dtype, shape):
         mod = tvm.IRModule()
@@ -1119,6 +1172,7 @@ def test_static_tensor_array_write():
     run('int32', [2, 3])
 
 
+@tvm.testing.uses_gpu
 def test_static_tensor_array_unstack():
     def run(dtype, shape):
         mod = tvm.IRModule()
@@ -1136,6 +1190,7 @@ def test_static_tensor_array_unstack():
     run('int32', [2, 3])
 
 
+@tvm.testing.uses_gpu
 def test_static_tensor_array_scatter():
     def run(dtype, shape, indices_shape=None):
         mod = tvm.IRModule()
@@ -1191,6 +1246,7 @@ def test_static_tensor_array_scatter():
     run('float32', [2, 3], [2,])
 
 
+@tvm.testing.uses_gpu
 def test_static_tensor_array_split():
     def run(dtype, shape, value_shape=None, lengths_shape=None):
         mod = tvm.IRModule()
@@ -1254,6 +1310,7 @@ def test_static_tensor_array_split():
     run('int32', [relay.Any(), 3], [4, 3], [2,])
 
 
+@tvm.testing.uses_gpu
 def test_static_tensor_array_concat():
     def run(dtype, shape):
         mod = tvm.IRModule()
@@ -1280,6 +1337,7 @@ def test_static_tensor_array_concat():
     run('int32', [relay.Any(), 3])
 
 
+@tvm.testing.uses_gpu
 def test_static_tensor_array_gather():
     def run(dtype, shape):
         mod = tvm.IRModule()
@@ -1307,6 +1365,7 @@ def test_static_tensor_array_gather():
     run('int32', [2, 3])
 
 
+@tvm.testing.uses_gpu
 def test_static_tensor_array_stack():
     def run(dtype, shape):
         mod = tvm.IRModule()
@@ -1332,6 +1391,7 @@ def test_static_tensor_array_stack():
     run('int32', [2, 3])
 
 
+@tvm.testing.uses_gpu
 def test_static_tensor_get_data():
     def run(dtype, shape):
         mod = tvm.IRModule()
@@ -1372,51 +1432,4 @@ def test_static_tensor_get_data():
     run('int32', [2, 3])
 
 if __name__ == "__main__":
-    test_nat_constructor()
-    test_double()
-    test_add()
-    test_list_constructor()
-    test_length()
-    test_map()
-    test_foldl()
-    test_foldr()
-    test_foldr1()
-    test_concat()
-    test_filter()
-    test_zip()
-    test_rev()
-    test_unfoldl()
-    test_unfoldr()
-    test_map_accumr()
-    test_map_accuml()
-    test_sum()
-    test_tmap()
-    test_size()
-    test_compose()
-    test_iterate()
-
-    test_tensor_expand_dims()
-    test_tensor_array_constructor()
-    test_tensor_array_read()
-    test_tensor_array_write()
-    test_tensor_array_stack()
-    test_tensor_array_unstack()
-    test_tensor_take()
-    test_tensor_concatenate()
-    test_tensor_array_concat()
-    test_tensor_array_scatter()
-    test_tensor_array_split()
-
-    test_static_tensor_take()
-    test_static_tensor_concatenate()
-    test_static_tensor_expand_dims()
-    test_static_tensor_array_constructor()
-    test_static_tensor_array_read()
-    test_static_tensor_array_write()
-    test_static_tensor_array_unstack()
-    test_static_tensor_array_scatter()
-    test_static_tensor_array_split()
-    test_static_tensor_array_concat()
-    test_static_tensor_array_stack()
-    test_static_tensor_array_gather()
-    test_static_tensor_get_data()
+    pytest.main([__file__])
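
With the manual call list replaced by pytest.main([__file__]) above, tests in this file are selected through pytest's collection instead of being edited into __main__. As a hypothetical usage example (the -k filter expression is illustrative):

    import pytest

    # Equivalent to `pytest tests/python/relay/test_adt.py -k tensor_array -v`
    # on the command line; runs only the tensor-array tests from this file.
    pytest.main(["tests/python/relay/test_adt.py", "-k", "tensor_array", "-v"])
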
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index e33e267..5efe08b 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -65,6 +65,7 @@ def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
     res_np = np_op(x_np, y_np)
     check_result([x_np, y_np], mod, res_np)
 
+@tvm.testing.uses_gpu
 def test_any_broadcast():
     # Test broadcast with 1s
     verify_any_broadcast((relay.Any(),), (3, 2), (1,), (3, 2), relay.add, np.add)
@@ -86,11 +87,13 @@ def verify_any_elemwise(x_shape, x_np_shape, op, np_op):
     res_np = np_op(x_np)
     check_result([x_np], mod, res_np)
 
+@tvm.testing.uses_gpu
 def test_any_elemwise():
     verify_any_elemwise((relay.Any(),), (3,), relay.sqrt, np.sqrt)
     verify_any_elemwise((relay.Any(), 2), (5, 2), relay.negative, np.negative)
     verify_any_elemwise((relay.Any(), relay.Any()), (5, 4), relay.exp, np.exp)
 
+@tvm.testing.uses_gpu
 def test_any_broadcast_fail():
     # Test broadcast with incompatible values at runtime
     def check_fail(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
@@ -117,6 +120,7 @@ def verify_any_full_like(x_shape, x_np_shape, relay_op, np_op, dtype='float32'):
     res_np = np_op(x_np)
     check_result([x_np], mod, res_np)
 
+@tvm.testing.uses_gpu
 def test_any_full_like():
     # zeros_like, ones_like
     verify_any_full_like(any_dims(3), (2, 3, 5), relay.zeros_like, np.zeros_like, "float32")
@@ -135,6 +139,7 @@ def verify_any_full(x_np_shape, relay_op, np_op, dtype='float32', value=None):
     x_np = np.array(x_np_shape).astype("int32")
     check_result([x_np], mod, res_np)
 
+@tvm.testing.uses_gpu
 def test_any_full():
     # zeros, ones, full
     verify_any_full((2, 3, 5), relay.zeros, np.zeros, "float32")
@@ -146,6 +151,7 @@ def test_any_full():
     verify_any_full((10, 11, 12, 13, 14), relay.full, np.full, "float32", 2.0)
     verify_any_full((1, 2, 3, 4), relay.full, np.full, "int32", -2)
 
+@tvm.testing.uses_gpu
 def test_any_concat():
     x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
     y = relay.var('y', shape=(1, 2), dtype="float32")
@@ -177,6 +183,7 @@ def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newsha
     mod["main"] = relay.Function(params, y)
     check_result(args, mod, data, flatten=True)
 
+@tvm.testing.uses_gpu
 def test_any_reshape():
     for variable_newshape in [False, True]:
         # Variable newshape only supports that output rank is the same as newshape
@@ -202,6 +209,7 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):
     # TODO(@zhiics) argwhere gpu schedule is currently not avaiable
     # check_result([data], mod, expected, flatten=True)
 
+@tvm.testing.uses_gpu
 def test_any_argwhere():
     verify_any_argwhere(any_dims(1), (5,))
     verify_any_argwhere(any_dims(2), (5, 5))
@@ -234,6 +242,7 @@ def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_s
     ref = np.take(data_np, indices_np, axis=axis)
     check_result([data_np, indices_np], mod, ref)
 
+@tvm.testing.uses_gpu
 def test_any_take():
     verify_any_take(any_dims(2), (1,), 0, (4, 5), (1,))
     verify_any_take(any_dims(2), (), 0, (4, 5), ())
@@ -251,12 +260,14 @@ def verify_any_tile(dshape, reps, np_dshape, np_reps):
     ref_res = np.tile(x_data, reps=np_reps)
     check_result([x_data], mod, ref_res)
 
+@tvm.testing.uses_gpu
 def test_any_tile():
     verify_any_tile(any_dims(3), (3, 2, 1), (2, 3, 4), (3, 2, 1))
     verify_any_tile(any_dims(3), (1, 2), (2, 3, 4), (1, 2))
     verify_any_tile(any_dims(2), (3, 2, 1), (2, 3), (3, 2, 1))
     verify_any_tile(any_dims(3), (1,), (2, 3, 4), (1,))
 
+@tvm.testing.uses_gpu
 def test_any_shape_of():
     x = relay.var('x', shape=any_dims(2), dtype='float32')
     y = relay.shape_of(x)
@@ -283,6 +294,7 @@ def verify_any_reduce(reduce_op, data_shape, axis, exclude, keepdims,
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+@tvm.testing.uses_gpu
 def test_any_reduce():
     verify_any_reduce(relay.argmax, any_dims(3), None, False, False, (3, 4, 5), ())
     verify_any_reduce(relay.argmin, any_dims(4), 1, False, True, (3, 4, 5, 6), (3, 1, 5, 6))
@@ -302,6 +314,7 @@ def verify_any_layout_transform(data_shape, src_layout, dst_layout, static_data_
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+@tvm.testing.uses_gpu
 def test_any_layout_transform():
     verify_any_layout_transform(any_dims(4), "NCHW", "NHWC", (3, 4, 5, 6), (3, 5, 6, 4))
     verify_any_layout_transform(any_dims(5), "NCHW16c", "NCHW2c", (1, 2, 8, 8, 16), (1, 16, 8, 8, 2))
@@ -318,6 +331,7 @@ def verify_any_expand_dims(data_shape, axis, num_newaxis, static_data_shape, ref
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+@tvm.testing.uses_gpu
 def test_any_expand_dims():
     verify_any_expand_dims(any_dims(3), 1, 2, (1, 2, 3), (1, 1, 1, 2, 3))
     verify_any_expand_dims(any_dims(3), -1, 2, (1, 2, 3), (1, 2, 3, 1, 1))
@@ -332,6 +346,7 @@ def verify_any_transpose(data_shape, axes, static_data_shape):
     ref_out = np.transpose(data_np, axes)
     check_result([data_np], mod, ref_out)
 
+@tvm.testing.uses_gpu
 def test_any_transpose():
     verify_any_transpose(any_dims(3), (1, 0, 2), (10, 3, 2))
     verify_any_transpose(any_dims(3), None, (2, 3, 4))
@@ -348,10 +363,12 @@ def verify_any_squeeze(data_shape, axis, static_data_shape):
     ref_out = np.squeeze(data_np, axis)
     check_result([data_np], mod, ref_out)
 
+@tvm.testing.uses_gpu
 def test_any_squeeze():
     verify_any_squeeze((1, relay.Any(), relay.Any()), (0,), (1, 9, 8))
     verify_any_squeeze((1, relay.Any(), relay.Any(), 1, relay.Any(), relay.Any()), (0, 3), (1, 12, 2, 1, 9, 17))
 
+@tvm.testing.uses_gpu
 def test_any_reshape_like():
     mod = tvm.IRModule()
     dtype = "float32"
@@ -399,6 +416,7 @@ def verify_any_pool2d(pool_type, data_shape, pool_size, strides, padding,
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+@tvm.testing.uses_gpu
 def test_any_pool2d():
     verify_any_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()),
                       (3, 3), (1, 1), (1, 1), "NCHW", (2, 3, 220, 220), (2, 3, 220, 220))
@@ -417,6 +435,7 @@ def verify_any_global_pool2d(pool_type, data_shape, layout, static_data_shape, r
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+@tvm.testing.uses_gpu
 def test_any_global_pool2d():
     verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()),
                       "NCHW", (2, 3, 220, 220), (2, 3, 1, 1))
@@ -439,12 +458,14 @@ def verify_any_split(data_shape, indices_or_sections, axis, static_data_shape, r
             assert ret.asnumpy().shape == ref_ret, \
                 "Shape mismatch: expect %s but got %s." % (str(ref_ret), str(ret.asnumpy().shape))
 
+@tvm.testing.uses_gpu
 def test_any_split():
     verify_any_split((relay.Any(), 4), 2, 1, (9, 4), [(9, 2), (9, 2)])
     verify_any_split((relay.Any(), relay.Any()), 2, 1, (9, 4), [(9, 2), (9, 2)])
     verify_any_split((relay.Any(), 12), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
     verify_any_split((relay.Any(), relay.Any()), (1, 4, 8), 1, (7, 12), [(7, 1), (7, 3), (7, 4)])
 
+@tvm.testing.uses_gpu
 def test_any_batch_flatten():
     mod = tvm.IRModule()
     dtype = "float32"
@@ -467,10 +488,13 @@ def verify_any_dense(data_shape, weight_shape, units, static_data_shape,
     weight_np = np.random.uniform(size=static_weight_shape).astype(dtype)
     check_result([data_np, weight_np], mod, ref_out_shape, assert_shape=True)
 
+# TODO(tvm-team) Fix dense schedule
+# @tvm.testing.uses_gpu
 def test_any_dense():
     verify_any_dense(any_dims(2), any_dims(2), None, (4, 16), (8, 16), (4, 8))
     verify_any_dense(any_dims(2), (50, relay.Any()), 50, (4, 40), (50, 40), (4, 50))
 
+@tvm.testing.uses_gpu
 def verify_any_pad(data_shape, pad_width, static_data_shape):
     mod = tvm.IRModule()
     dtype = "float32"
@@ -481,6 +505,7 @@ def verify_any_pad(data_shape, pad_width, static_data_shape):
     ref_out = np.pad(data_np, pad_width)
     check_result([data_np], mod, ref_out)
 
+@tvm.testing.uses_gpu
 def test_any_pad():
     verify_any_pad(any_dims(3), ((0, 0), (1, 1), (2, 2)), (1, 2, 3))
     verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1))
@@ -499,6 +524,7 @@ def verify_any_dilate(data_shape, strides, static_data_shape):
     ref_out[tuple(slice(None, None, strides[i]) for i in range(len(data_shape)))] = data_np
     check_result([data_np], mod, ref_out)
 
+@tvm.testing.uses_gpu
 def test_any_dilate():
     verify_any_dilate(any_dims(1), (1,), (1,))
     verify_any_dilate(any_dims(1), (1,), (5,))
@@ -518,6 +544,7 @@ def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape):
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+@tvm.testing.uses_gpu
 def test_any_softmax():
     verify_any_softmax(any_dims(3), -1, (1, 2, 3), (1, 2, 3))
     verify_any_softmax(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1))
@@ -556,6 +583,7 @@ def test_any_topk():
     verify_any_topk(any_dims(2), 2, (6, 3), "int32")
     verify_any_topk(any_dims(2), 3, (6, 3), "float32", True)
 
+@tvm.testing.uses_gpu
 def test_fused_ops():
     x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32')
     y0 = x + relay.const(1.0, 'float32')
@@ -565,6 +593,7 @@ def test_fused_ops():
     data = np.random.uniform(size=(5, 4)).astype('float32')
     check_result([data], mod, (data + 1) * 2)
 
+@tvm.testing.uses_gpu
 def test_arange_with_dynamic_shape():
     # m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
     m, n, k = relay.Any(), relay.Any(), relay.Any()
@@ -611,6 +640,7 @@ def verify_any_strided_slice(data_shape, begin_shape, end_shape, strides_shape,
 
     check_result(np_inputs, mod, ref_res)
 
+@tvm.testing.uses_gpu
 def test_any_strided_slice():
     verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21))
     verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21))
@@ -619,7 +649,7 @@ def test_any_strided_slice():
     verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21), slice_mode="size")
     verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21), const_attrs=True)
 
-
+@tvm.testing.uses_gpu
 def test_recursive_concat():
     """
     fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) {
@@ -654,6 +684,7 @@ def test_recursive_concat():
     ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
     check_result([data], mod, ref)
 
+@tvm.testing.uses_gpu
 def test_recursive_concat_with_wrong_annotation():
     """
     v0.0.1
@@ -701,6 +732,7 @@ def test_recursive_concat_with_wrong_annotation():
     except Exception as e:
         assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
 
+@tvm.testing.uses_gpu
 def test_tuple_get_item():
     mod = tvm.IRModule()
     dtype = "float32"
@@ -716,6 +748,7 @@ def test_tuple_get_item():
     ref_out_shape = (9, 2)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+@tvm.testing.uses_gpu
 def test_mixed_input_type():
     mod = tvm.IRModule()
     dtype = "float32"
@@ -750,6 +783,7 @@ def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_
     box_indices_np = np.random.uniform(size=static_box_indices_shape).astype(indices_dtype)    
     check_result([data_np, boxes_np, box_indices_np], mod, ref_out_shape, assert_shape=True)
 
+@tvm.testing.uses_gpu
 def test_any_crop_and_resize():
     verify_any_crop_and_resize(
         data_shape=(1, 234, 234, 256),
@@ -780,6 +814,7 @@ def verify_any_mirror_pad(data_shape, pad_width, static_data_shape, ref_out_shap
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     check_result([data_np], mod, ref_out_shape, assert_shape=True)
 
+@tvm.testing.uses_gpu
 def test_any_mirror_pad():
     verify_any_mirror_pad(
         data_shape=(1, 256, 232, 232),
@@ -796,45 +831,11 @@ def verify_any_ndarray_size(data_np_shape):
     ref_res = np.size(np_data)
     check_result([np_data], mod, ref_res)
 
+@tvm.testing.uses_gpu
 def test_any_ndarray_size():
     verify_any_ndarray_size((2,))
     verify_any_ndarray_size((2, 2))
     verify_any_ndarray_size((1, 2, 3, 4))
 
 if __name__ == "__main__":
-    test_any_full()
-    test_any_full_like()
-    test_any_broadcast()
-    test_any_elemwise()
-    test_any_broadcast_fail()
-    test_any_concat()
-    test_any_reshape()
-    test_any_take()
-    test_any_tile()
-    test_any_split()
-    test_any_shape_of()
-    test_any_reduce()
-    test_any_layout_transform()
-    test_any_expand_dims()
-    test_any_transpose()
-    test_any_squeeze()
-    test_any_reshape_like()
-    test_any_conv2d_NCHWc()
-    test_any_pool2d()
-    test_any_global_pool2d()
-    test_any_batch_flatten()
-    test_any_dense()
-    test_any_pad()
-    test_any_softmax()
-    test_any_topk()
-    test_fused_ops()
-    test_any_argwhere()
-    test_arange_with_dynamic_shape()
-    test_any_strided_slice()
-    test_recursive_concat()
-    test_recursive_concat_with_wrong_annotation()
-    test_tuple_get_item()
-    test_mixed_input_type()
-    test_any_crop_and_resize()
-    test_any_mirror_pad()
-    test_any_ndarray_size()
+    pytest.main([__file__])
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 710025a..1e8069b 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -39,11 +39,7 @@ def check_result(args, expected_result, mod=None):
     expected_result:
         The expected result of running the expression.
     """
-    # TODO(@zhiics, @icemelon9): Disable the gpu test for now until the heterogeneous support
-    #   is ready
     for target, ctx in tvm.testing.enabled_targets():
-        if "cuda" in target:
-            continue
         vm = relay.create_executor('vm', ctx=ctx, target=target, mod=mod)
         rts_result = vm.evaluate()(*args)
         tvm.testing.assert_allclose(expected_result, rts_result.asnumpy())
@@ -70,17 +66,20 @@ def vmobj_to_list(o):
     else:
         raise RuntimeError("Unknown object type: %s" % type(o))
 
+@tvm.testing.uses_gpu
 def test_split():
     x = relay.var('x', shape=(12,))
     y = relay.split(x, 3, axis=0).astuple()
     f = relay.Function([x], y)
 
     x_data = np.random.rand(12,).astype('float32')
-    res = veval(f, x_data)
     ref_res = np.split(x_data, 3, axis=0)
-    for i in range(3):
-        tvm.testing.assert_allclose(res[i].asnumpy(), ref_res[i])
+    for tgt, ctx in tvm.testing.enabled_targets():
+        res = veval(f, x_data, ctx=ctx, target=tgt)
+        for i in range(3):
+            tvm.testing.assert_allclose(res[i].asnumpy(), ref_res[i])
 
+@tvm.testing.uses_gpu
 def test_split_no_fuse():
     x = relay.var('x', shape=(12,))
     y = relay.split(x, 3, axis=0).astuple()
@@ -88,8 +87,9 @@ def test_split_no_fuse():
     z = relay.annotation.stop_fusion(z)
     f = relay.Function([x], z)
     x_data = np.random.rand(12,).astype('float32')
-    res = veval(f, x_data)
-    tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
+    for tgt, ctx in tvm.testing.enabled_targets():
+        res = veval(f, x_data, ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(res.asnumpy(), np.split(x_data, 3, axis=0)[0])
 
 @tvm.testing.uses_gpu
 def test_id():
@@ -147,6 +147,7 @@ def test_simple_if():
     # diff
     check_result([x_data, y_data], y_data, mod=mod)
 
+@tvm.testing.uses_gpu
 def test_multiple_ifs():
     mod = tvm.IRModule({})
     b = relay.var('b')
@@ -197,8 +198,9 @@ def test_count_loop():
     i_data = np.array(0, dtype='int32')
     iarg = relay.var('i', shape=[], dtype='int32')
     mod["main"] = relay.Function([iarg], sum_up(iarg))
-    result = veval(mod, i_data)
-    tvm.testing.assert_allclose(result.asnumpy(), i_data)
+    for tgt, ctx in tvm.testing.enabled_targets():
+        result = veval(mod, i_data, ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(result.asnumpy(), i_data)
     check_result([i_data], i_data, mod=mod)
 
 @tvm.testing.uses_gpu
@@ -246,6 +248,7 @@ def test_tuple_second():
     mod["main"] = f
     check_result([(i_data, j_data)], j_data, mod=mod)
 
+@tvm.testing.uses_gpu
 def test_list_constructor():
     mod = tvm.IRModule()
     p = Prelude(mod)
@@ -261,12 +264,13 @@ def test_list_constructor():
 
     mod["main"] = f
 
-    result = veval(mod)
-    assert len(result) == 2
-    assert len(result[1]) == 2
+    for tgt, ctx in tvm.testing.enabled_targets():
+        result = veval(mod, ctx=ctx, target=tgt)
+        assert len(result) == 2
+        assert len(result[1]) == 2
 
-    obj = vmobj_to_list(result)
-    tvm.testing.assert_allclose(obj, np.array([3,2,1]))
+        obj = vmobj_to_list(result)
+        tvm.testing.assert_allclose(obj, np.array([3,2,1]))
 
 @tvm.testing.uses_gpu
 def test_let_tensor():
@@ -304,6 +308,7 @@ def test_let_scalar():
     mod["main"] = f
     check_result([x_data], x_data + 42.0, mod=mod)
 
+@tvm.testing.uses_gpu
 def test_compose():
     mod = tvm.IRModule()
     p = Prelude(mod)
@@ -334,9 +339,11 @@ def test_compose():
     mod["main"] = f
 
     x_data = np.array(np.random.rand()).astype('float32')
-    result = veval(mod, [x_data])
-    tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
+    for tgt, ctx in tvm.testing.enabled_targets():
+        result = veval(mod, [x_data], ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0)
 
+@tvm.testing.uses_gpu
 def test_list_hd():
     mod = tvm.IRModule()
     p = Prelude(mod)
@@ -354,8 +361,9 @@ def test_list_hd():
 
     mod["main"] = f
 
-    result = veval(mod)
-    tvm.testing.assert_allclose(result.asnumpy(), 3)
+    for tgt, ctx in tvm.testing.enabled_targets():
+        result = veval(mod, ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(result.asnumpy(), 3)
 
 @pytest.mark.xfail
 def test_list_tl_empty_list():
@@ -370,9 +378,10 @@ def test_list_tl_empty_list():
 
     mod["main"] = f
 
-    result = veval(mod)
-    print(result)
+    for tgt, ctx in tvm.testing.enabled_targets():
+        result = veval(mod, ctx=ctx, target=tgt)
 
+@tvm.testing.uses_gpu
 def test_list_tl():
     mod = tvm.IRModule()
     p = Prelude(mod)
@@ -390,9 +399,11 @@ def test_list_tl():
 
     mod["main"] = f
 
-    result = veval(mod)
-    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([2,1]))
+    for tgt, ctx in tvm.testing.enabled_targets():
+        result = veval(mod, ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(vmobj_to_list(result), np.array([2,1]))
 
+@tvm.testing.uses_gpu
 def test_list_nth():
     expected = list(range(10))
 
@@ -409,9 +420,11 @@ def test_list_nth():
 
         f = relay.Function([], nth(l, relay.const(i)))
         mod["main"] = f
-        result = veval(mod)
-        tvm.testing.assert_allclose(result.asnumpy(), expected[i])
+        for tgt, ctx in tvm.testing.enabled_targets():
+            result = veval(mod, ctx=ctx, target=tgt)
+            tvm.testing.assert_allclose(result.asnumpy(), expected[i])
 
+@tvm.testing.uses_gpu
 def test_list_update():
     expected = list(range(10))
 
@@ -433,9 +446,11 @@ def test_list_update():
 
     f = relay.Function([], l)
     mod["main"] = f
-    result = veval(mod)
-    tvm.testing.assert_allclose(vmobj_to_list(result), np.array(expected))
+    for tgt, ctx in tvm.testing.enabled_targets():
+        result = veval(mod, ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(vmobj_to_list(result), np.array(expected))
 
+@tvm.testing.uses_gpu
 def test_list_length():
     expected = list(range(10))
 
@@ -455,9 +470,11 @@ def test_list_length():
 
     f = relay.Function([], l)
     mod["main"] = f
-    result = veval(mod)
-    tvm.testing.assert_allclose(result.asnumpy(), 10)
+    for tgt, ctx in tvm.testing.enabled_targets():
+        result = veval(mod, ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(result.asnumpy(), 10)
 
+@tvm.testing.uses_gpu
 def test_list_map():
     mod = tvm.IRModule()
     p = Prelude(mod)
@@ -473,9 +490,11 @@ def test_list_map():
 
     f = relay.Function([], map(add_one_func, l))
     mod["main"] = f
-    result = veval(mod)
-    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 2]))
+    for tgt, ctx in tvm.testing.enabled_targets():
+        result = veval(mod, ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 2]))
 
+@tvm.testing.uses_gpu
 def test_list_foldl():
     mod = tvm.IRModule()
     p = Prelude(mod)
@@ -491,9 +510,11 @@ def test_list_foldl():
     l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
     f = relay.Function([], foldl(rev_dup_func, nil(), l))
     mod["main"] = f
-    result = veval(mod)
-    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 3, 2, 2, 1, 1]))
+    for tgt, ctx in tvm.testing.enabled_targets():
+        result = veval(mod, ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 3, 2, 2, 1, 1]))
 
+@tvm.testing.uses_gpu
 def test_list_foldr():
     mod = tvm.IRModule()
     p = Prelude(mod)
@@ -509,9 +530,11 @@ def test_list_foldr():
     l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
     f = relay.Function([], foldr(identity_func, nil(), l))
     mod["main"] = f
-    result = veval(mod)
-    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([1, 2, 3]))
+    for tgt, ctx in tvm.testing.enabled_targets():
+        result = veval(mod, ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(vmobj_to_list(result), np.array([1, 2, 3]))
 
+@tvm.testing.uses_gpu
 def test_list_sum():
     mod = tvm.IRModule()
     p = Prelude(mod)
@@ -523,9 +546,11 @@ def test_list_sum():
     l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
     f = relay.Function([], sum(l))
     mod["main"] = f
-    result = veval(mod)
-    tvm.testing.assert_allclose(result.asnumpy(), 6)
+    for tgt, ctx in tvm.testing.enabled_targets():
+        result = veval(mod, ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(result.asnumpy(), 6)
 
+@tvm.testing.uses_gpu
 def test_list_filter():
     mod = tvm.IRModule()
     p = Prelude(mod)
@@ -543,9 +568,11 @@ def test_list_filter():
                         cons(relay.const(1), nil())))))
     f = relay.Function([], filter(greater_than_one, l))
     mod["main"] = f
-    result = veval(mod)
-    tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 5]))
+    for tgt, ctx in tvm.testing.enabled_targets():
+        result = veval(mod, ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 5]))
 
+@tvm.testing.uses_gpu
 def test_closure():
     x = relay.var('x', shape=())
     y = relay.var('y', shape=())
@@ -553,8 +580,9 @@ def test_closure():
     ff = relay.Function([y], f)
     clo = ff(relay.const(1.0))
     main = clo(relay.const(2.0))
-    res = veval(main)
-    tvm.testing.assert_allclose(res.asnumpy(), 3.0)
+    for tgt, ctx in tvm.testing.enabled_targets():
+        res = veval(main, ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(res.asnumpy(), 3.0)
 
 @tvm.testing.uses_gpu
 def test_add_op_scalar():