Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/30 07:39:04 UTC

[tvm] branch main updated: [ONNX] Enable GPU in ONNX importer tests (#7438)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1c2555a  [ONNX] Enable GPU in ONNX importer tests (#7438)
1c2555a is described below

commit 1c2555a21322be5f45fa42b30f87a04b4915cd83
Author: Lily Orth-Smith <li...@gmail.com>
AuthorDate: Tue Mar 30 00:38:36 2021 -0700

    [ONNX] Enable GPU in ONNX importer tests (#7438)
    
    * remove hardcoded target and ctx
    
    * fix C codegen for floating point mod
    
    * Disable ONNX GPU test for argmin / argmax so we can get this fix merged. Matt or I will fix it later; we don't have time right now.
    
    * lint
    
    * fix black
    
    * Add flag to task_python_frontend.sh to only run GPU enabled tests on GPU
    
    * black again
    
    * Enable GPU for test_nonzero
    
    * Respond to comments
    
    * Don't test batch matmul on CUDA
    
    * Turn cuda off for dynamic batch matmul test
    
    * Fix task script
    
    * Flaky test
    
    * another flaky test
    
    Co-authored-by: mbrookhart <mb...@octoml.ai>
---
 src/target/source/codegen_c.cc             | 16 ++++++++++++-
 tests/python/frontend/onnx/test_forward.py | 38 +++++++++++-------------------
 tests/scripts/task_python_frontend.sh      |  5 +++-
 3 files changed, 33 insertions(+), 26 deletions(-)

diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 986ef8e..5eaf1e8 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -525,7 +525,21 @@ void CodeGenC::VisitExpr_(const DivNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "/", os, this);
 }
 void CodeGenC::VisitExpr_(const ModNode* op, std::ostream& os) {  // NOLINT(*)
-  PrintBinaryExpr(op, "%", os, this);
+  if (op->dtype.is_int() || op->dtype.is_uint()) {
+    PrintBinaryExpr(op, "%", os, this);
+  } else {
+    ICHECK(op->dtype.is_float()) << "Expected floating point or integer dtype in Mod, but got "
+                                 << op->dtype;
+    if (op->dtype.bits() == 32) {
+      PrintBinaryExpr(op, "fmodf", os, this);
+    } else if (op->dtype.bits() == 64) {
+      PrintBinaryExpr(op, "fmod", os, this);
+    } else {
+      ICHECK(false)
+          << "Non single or double precision floating point in Mod, expected 32 or 64 bits but got "
+          << op->dtype.bits() << " bits.";
+    }
+  }
 }
 void CodeGenC::VisitExpr_(const MinNode* op, std::ostream& os) {  // NOLINT(*)
   PrintBinaryExpr(op, "min", os, this);
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 04b6c94..ce05554 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -276,8 +276,7 @@ def test_double_reshape():
         tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
 
 
-# TODO(mbrookhart): enable once VM supports heterogenous execution
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
 def test_expand():
     def _test_expand(name, data, shape, ref_data, dtype="int32"):
         shape_array = np.array(shape)
@@ -757,8 +756,7 @@ def _test_slice_iteration_v10(indata, outdata, **attrs):
     verify_with_ort_with_inputs(model, [indata], opset=10, freeze_params=True, use_vm=True)
 
 
-# TODO(mbrookhart): enable once VM supports heterogenous execution
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
 def test_slice():
     x = np.random.randn(20, 10, 5).astype(np.float32)
     _test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1))
@@ -978,8 +976,7 @@ def test_gather_nd():
     verify_gather_nd([2, 2, 2], [[[0, 1]], [[1, 0]]], [2, 1, 2])
 
 
-# TODO(mbrookhart): enable once VM supports heterogenous execution
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
 def test_onehot():
     indices_shape = [10]
     indices_array = np.random.randint(low=0, high=9, size=indices_shape, dtype="int32")
@@ -1091,7 +1088,7 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, target, dev):
     verify_with_ort_with_inputs(model, [a_array, b_array], use_vm=True, targets=[target])
 
 
-# TODO(mbrookhart): enable cuda once VM supports heterogenous execution
+# TODO(mbrookhart, electriclilies): Add CUDA as a target once batch matmul is fixed
 @tvm.testing.parametrize_targets("llvm")
 def test_batch_matmul(target, dev):
     verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4), target, dev)
@@ -1146,7 +1143,7 @@ def verify_simple_dynamic_model(a_shape, b_shape, target, dev):
     verify_model(ex, [a * 3 for a in a_shape], [b * 3 for b in b_shape])
 
 
-# TODO(mbrookhart): enable cuda once VM supports heterogenous execution
+# TODO(mbrookhart, electriclilies): Add CUDA as a target once batch matmul is fixed
 @tvm.testing.parametrize_targets("llvm")
 def test_batch_matmul_dynamic_model(target, dev):
     verify_simple_dynamic_model((2, 3, 4, 3), (2, 3, 3, 4), target, dev)
@@ -1319,8 +1316,7 @@ def verify_upsample3d_trilinear():
         tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
 
 
-# TODO(mbrookhart): enable once VM supports heterogenous execution
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
 def test_upsample():
     verify_upsample_nearest()
     verify_upsample_bilinear()
@@ -1497,7 +1493,8 @@ def verify_argreduce(input_dim, op_name, axis=None, keepdims=None):
     verify_with_ort_with_inputs(model, [a_np1])
 
 
-@tvm.testing.uses_gpu
+# TODO (mbrookhart, electriclilies) Fix argmin on GPU and enable this test
+# @tvm.testing.uses_gpu
 def test_forward_arg_min_max():
     """Verify argmin and argmax"""
     verify_argreduce([3, 4, 4], "ArgMin")
@@ -1540,8 +1537,7 @@ def verify_constantofshape(input_dim, value, dtype):
     verify_with_ort_with_inputs(model, [input_np], use_vm=True)
 
 
-# TODO(mbrookhart): enable once VM supports heterogenous execution
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
 def test_constantofshape():
     verify_constantofshape((2, 3, 4, 5), 10, "float32")
     verify_constantofshape((3, 3), 0, "int32")
@@ -1627,8 +1623,7 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0):
     verify_with_ort_with_inputs(model, inputs, opset=11, use_vm=True)
 
 
-# TODO(mbrookhart): enable once VM supports heterogenous execution
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
 def test_pad():
     verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], "constant", 0.0)
     verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], "constant", 0.0)
@@ -2120,8 +2115,7 @@ def verify_tile_v6(indata, repeats, outdata):
     verify_with_ort_with_inputs(model, [indata, repeats], use_vm=True, opset=6)
 
 
-# TODO(mbrookhart): enable once VM supports heterogenous execution
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
 def test_tile():
     x = np.random.rand(2, 3, 4, 5).astype(np.float32)
     repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
@@ -2293,8 +2287,7 @@ def test_batch_norm():
     verify_batch_norm([16, 16, 10, 10])
 
 
-# TODO(mbrookhart): enable once VM supports heterogenous execution
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
 def test_batch_norm_dynamic_subgraph():
     def verify_batch_norm_dynamic_subgraph(in_shape, o_shape):
 
@@ -3312,8 +3305,7 @@ def test_gru():
     )
 
 
-# TODO(mbrookhart): enable once VM supports heterogenous execution
-# @tvm.testing.uses_gpu
+@tvm.testing.uses_gpu
 def test_resize():
     def verify(ishape, oshape, scales, mode, coord_trans):
         nodes = [
@@ -3420,9 +3412,7 @@ def test_nonzero():
 
         model = helper.make_model(graph, producer_name="nonzero_test")
 
-        verify_with_ort_with_inputs(
-            model, [indata], targets=["llvm"], dtype="int64", use_vm=True, opset=9
-        )
+        verify_with_ort_with_inputs(model, [indata], dtype="int64", use_vm=True, opset=9)
 
     input_data = np.array([[1, 0], [1, 1]], dtype=np.int64)
     result = np.array((np.nonzero(input_data)))  # expected output [[0, 1, 1], [0, 0, 1]]
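
The test_forward.py changes above replace commented-out decorators with live
ones. For reference, a hedged sketch of the two tvm.testing styles involved
(the decorators and enabled_targets are from tvm.testing as of this commit;
run_one_case is a hypothetical stand-in for the verify_* helpers):

    import tvm
    import tvm.testing

    @tvm.testing.uses_gpu
    def test_runs_on_all_enabled_targets():
        # Loops over every target enabled for the current job,
        # including GPU targets when a device is available.
        for target, dev in tvm.testing.enabled_targets():
            run_one_case(target, dev)

    @tvm.testing.parametrize_targets("llvm")
    def test_cpu_only(target, dev):
        # pytest parametrizes the test over just the listed targets;
        # this is how the batch matmul tests stay off CUDA for now.
        run_one_case(target, dev)
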
diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh
index 62a0fa1..fb388a6 100755
--- a/tests/scripts/task_python_frontend.sh
+++ b/tests/scripts/task_python_frontend.sh
@@ -35,7 +35,10 @@ echo "Running relay MXNet frontend test..."
 run_pytest cython python-frontend-mxnet tests/python/frontend/mxnet
 
 echo "Running relay ONNX frontend test..."
-run_pytest cython python-frontend-onnx tests/python/frontend/onnx
+# Enable tvm.testing decorators in the ONNX importer test (not enabling in the other tests because
+# they do not consistently use the decorators to indicate that tests should run on GPU).
+# In the future, we should enable tvm.testing decorators for all the test files.
+PYTEST_ADDOPTS="-m gpu $PYTEST_ADDOPTS" run_pytest cython python-frontend-onnx tests/python/frontend/onnx
 
 echo "Running relay CoreML frontend test..."
 run_pytest cython python-frontend-coreml tests/python/frontend/coreml
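
A note on the PYTEST_ADDOPTS line above: pytest's -m option filters collected
tests by marker, and the tvm.testing decorators attach a "gpu" marker to tests
that can run on a GPU (this is what the script comment means by "enable
tvm.testing decorators"), so prepending -m gpu restricts this job to
GPU-enabled tests. A plain-pytest sketch of the same selection mechanism (the
test body is illustrative):

    import pytest

    @pytest.mark.gpu
    def test_selected_by_gpu_marker():
        # Collected only when pytest runs with "-m gpu", e.g.
        #   PYTEST_ADDOPTS="-m gpu" pytest tests/python/frontend/onnx
        assert True
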