You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jr...@apache.org on 2021/03/24 02:40:50 UTC

[tvm] branch main updated: [ONNX] Onnx node tests (#7720)

This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8131364  [ONNX] Onnx node tests (#7720)
8131364 is described below

commit 813136401a11a49d6c15e6013c34dd822a5c4ff6
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Tue Mar 23 20:40:32 2021 -0600

    [ONNX] Onnx node tests (#7720)
    
    * WIP
    
    * some fixes
    
    * more fixes
    
    * fix some conv_transpose tests
    
    * fix out-of-bounds slice
    
    * fix flatten import
    
    * fix logsoftmax and softmax tests
    
    * fix Error in Upsample
    
    * fix onehot
    
    * normalize errors
    
    * fix gather with negative indices
    
    * parameterize test
    
    * skip unsupported tests
    
    * clean up
    
    * fix rebase
    
    * fix lint
    
    * add an error message when we find an unidentified tensor
---
 python/tvm/relay/frontend/onnx.py          | 133 +++++++++++++++++------
 python/tvm/relay/op/transform.py           |   7 +-
 tests/python/frontend/onnx/test_forward.py | 163 +++++++++++++++++++++++++++++
 3 files changed, 269 insertions(+), 34 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index fab4ae8..d9fc2ff 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -103,10 +103,11 @@ def get_numpy(tensor_proto):
 def get_type(elem_type):
     """Converts onnx integer datatype to numpy datatype"""
     try:
-        from onnx import TensorProto
+        from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
     except ImportError as e:
         raise ImportError("Unable to import onnx which is required {}".format(e))
-    return TensorProto.DataType.Name(elem_type).lower()
+    # Table maps TensorProto enum -> np.dtype; str() gives e.g. "float32" (KeyError if unknown).
+    return str(TENSOR_TYPE_TO_NP_TYPE[elem_type])
 
 
 def get_info(info_proto):
@@ -157,7 +158,7 @@ def revert_caffe2_pad(pads):
     return pads
 
 
-def get_pad_pair(input1d, kernel1d, stride1d):
+def get_pad_pair(input1d, kernel1d, stride1d, mode):
     """infer pad size"""
     if input1d % stride1d == 0:
         pad = max(kernel1d - stride1d, 0)
@@ -165,6 +166,8 @@ def get_pad_pair(input1d, kernel1d, stride1d):
         pad = max(kernel1d - (input1d % stride1d), 0)
     pad_before = pad // 2
     pad_after = pad - pad_before
+    if "LOWER" in mode:
+        return [pad_after, pad_before]
     return [pad_before, pad_after]
 
 
@@ -280,9 +283,9 @@ class Pool(OnnxOpConverter):
                     pad_tuple = []
                     for axis in range(len(input_shape) - 2):
                         axis_shape = input_shape[2 + axis]
-                        stride = attr["strides"][axis]
+                        stride = attr.get("strides", [1] * ndim)[axis]
                         kernel = attr["kernel_shape"][axis]
-                        pad = get_pad_pair(axis_shape, kernel, stride)
+                        pad = get_pad_pair(axis_shape, kernel, stride, attr["auto_pad"])
                         pad_tuple.append(pad)
                     pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair])
                     attr["pads"] = pad_tuple
@@ -444,9 +447,15 @@ class ConvTranspose(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         # get number of channels
-        channels = infer_channels(inputs[1], True)
+        out_type = infer_type(inputs[1])
+        out_shapes = [get_const_tuple(out_type.checked_type.shape)]
+        channels = out_shapes[0][1]
         attr["channels"] = channels
         groups = attr.get("group", 1)
+
+        if "kernel_shape" not in attr:
+            attr["kernel_shape"] = out_shapes[0][2:]
+
         attr["groups"] = groups
         # infer pads for auto_pad
         data = inputs[0]
@@ -528,13 +537,11 @@ class Gemm(OnnxOpConverter):
         if not transB:
             inputs[1] = _op.transpose(inputs[1], axes=(1, 0))
         inputs[0] = _op.nn.batch_flatten(inputs[0])
-
         if alpha != 1.0:
             inputs[0] *= _expr.const(alpha)
         out = _op.nn.dense(inputs[0], inputs[1], units=channels)
-
         if len(inputs) == 3:
-            return _op.nn.bias_add(out, _expr.const(beta) * inputs[2])
+            out = out + _expr.const(beta) * inputs[2]
         return out
 
 
@@ -618,7 +625,7 @@ class Mod(OnnxOpConverter):
         # Note: attr['fmod'] determines whether the operator should behave like np.fmod or np.mod.
         # attr['fmod'] == 0 will behave as np.mod and attr['fmod'] == 1 will force fmod treatment.
         # The relay equivalent of np.fmod is relay.mod and np.mod is relay.floor_mod
-        if attr["fmod"] == 0:
+        if attr.get("fmod", 0) == 0:
             op_name = "floor_mod"
         else:
             op_name = "mod"
@@ -849,12 +856,18 @@ class Flatten(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         axis = attr.get("axis", 1)
+        ishape = _op.shape_of(inputs[0])
+        ndim = infer_shape(ishape)[0]
+        if axis < 0:
+            axis = axis + ndim
+
         if axis == 1:
             out = _op.nn.batch_flatten(inputs[0])
         else:
-            newshape = [0] * (axis + 1)
-            newshape[axis] = -1
-            out = _op.reshape(inputs[0], list(newshape))
+            pre_shape = _op.prod(_op.strided_slice(ishape, [0], [axis], [1]), keepdims=True)
+            post_shape = _op.prod(_op.strided_slice(ishape, [axis], [ndim], [1]), keepdims=True)
+            newshape = _op.concatenate([pre_shape, post_shape], axis=0)
+            out = _op.reshape(inputs[0], newshape)
         return out
 
 
@@ -1036,7 +1049,7 @@ class Upsample(OnnxOpConverter):
 
         # in 3d case, we use the purely static op
         if dims == 5:
-            if isinstance(scales, _expr.Call):
+            if isinstance(scales, _expr.Expr):
                 scale_h = _op.take(scales, _op.const(3))
                 scale_w = _op.take(scales, _op.const(4))
                 scale_d = _op.take(scales, _op.const(1))
@@ -1052,7 +1065,7 @@ class Upsample(OnnxOpConverter):
             )
         # in 2d case, use dynamic op
         else:
-            if isinstance(scales, _expr.Call):
+            if isinstance(scales, _expr.Expr):
                 scale_h = _op.take(scales, _op.const(3))
                 scale_w = _op.take(scales, _op.const(4))
             else:
@@ -1247,7 +1260,13 @@ class Gather(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         axis = attr.get("axis", 0)
-        return AttrCvt("take", extras={"axis": axis})(inputs, {})
+        data = inputs[0]
+        indices = inputs[1]
+        ind_dtype = infer_type(indices).checked_type.dtype
+        # Normalize the indices to a positive range
+        s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis))
+        indices = _op.where(indices < _op.const(0, ind_dtype), indices + s, indices)
+        return _op.take(data, indices, axis)
 
 
 class GatherElements(OnnxOpConverter):
@@ -1258,6 +1277,10 @@ class GatherElements(OnnxOpConverter):
         data = inputs[0]
         indices = inputs[1]
         axis = attr.get("axis", 0)
+        ind_dtype = infer_type(indices).checked_type.dtype
+        # Normalize the indices to a positive range
+        s = _op.take(_op.shape_of(data, dtype=ind_dtype), _op.const(axis))
+        indices = _op.where(indices < _op.const(0, ind_dtype), indices + s, indices)
         return _op.gather(data, axis, indices)
 
 
@@ -1318,8 +1341,8 @@ class Maximum(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
-            raise ValueError("Expect minimum 2 inputs")
+        if len(inputs) == 1:
+            return inputs[0]
         _max = inputs[0]
         for i in range(1, len(inputs)):
             _max = AttrCvt("maximum")([_max, inputs[i]], {})
@@ -1331,8 +1354,8 @@ class Minimum(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
-            raise ValueError("Expect minimum 2 inputs")
+        if len(inputs) == 1:
+            return inputs[0]
         _min = inputs[0]
         for i in range(1, len(inputs)):
             _min = AttrCvt("minimum")([_min, inputs[i]], {})
@@ -1344,8 +1367,8 @@ class Mean(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
-            raise ValueError("Expect minimum 2 inputs")
+        if len(inputs) == 1:
+            return inputs[0]
         # avoid overflow
         concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
         return _op.mean(concat, axis=0, keepdims=False)
@@ -1485,6 +1508,8 @@ class ArgMax(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
+        if "select_last_index" in attr:
+            raise NotImplementedError("select_last_index not supported in ArgMax")
         axis = attr.get("axis", 0)
         keepdims = attr.get("keepdims", True)
         attr = {"axis": axis, "keepdims": keepdims}
@@ -1496,6 +1521,8 @@ class ArgMin(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
+        if "select_last_index" in attr:
+            raise NotImplementedError("select_last_index not supported in ArgMin")
         axis = attr.get("axis", 0)
         keepdims = attr.get("keepdims", True)
         attr = {"axis": axis, "keepdims": keepdims}
@@ -1510,7 +1537,35 @@ class Softmax(OnnxOpConverter):
         # set default value when axis is not set in the model
         if "axis" not in attr:
             attr["axis"] = 1
-        return AttrCvt("softmax", transforms={"axis": ("axis", 1)})(inputs, attr, params)
+        axis = attr["axis"]
+        ndim = len(infer_shape(inputs[0]))
+        if axis < 0:
+            axis += ndim
+        axes = list(range(axis, ndim))
+        x = inputs[0]
+        m = _op.max(x, axes, keepdims=True)
+        e = _op.exp(x - m)
+        return e / _op.sum(e, axes, keepdims=True)
+
+
+class LogSoftmax(OnnxOpConverter):
+    """Operator converter for LogSoftmax (opset < 13 semantics; no _impl_v13 yet)."""
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # ONNX defaults axis to 1 when it is not set in the model.
+        axis = attr.get("axis", 1)
+        x = inputs[0]
+        ndim = len(infer_shape(x))
+        if axis < 0:
+            axis += ndim
+        # Reduce over every axis from `axis` onward, matching the 2D coercion of the
+        # input at `axis` that LogSoftmax uses before opset 13.
+        axes = list(range(axis, ndim))
+        # Subtract the max for numerical stability: x - m - log(sum(exp(x - m))).
+        m = _op.max(x, axes, keepdims=True)
+        s = _op.sum(_op.exp(x - m), axes, keepdims=True)
+        return x - m - _op.log(s)
 
 
 class OneHot(OnnxOpConverter):
@@ -1520,14 +1575,24 @@ class OneHot(OnnxOpConverter):
     def _impl_v9(cls, inputs, attr, params):
         # Extract relay one_hot inputs.
         indices, depth, values = inputs
+        ndim = len(infer_shape(indices))
         # Split onnx on off values into two separate expressions.
         off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, _op.const(1))
         # Extract the datatype of the output from on_value.
         dtype = infer_type(on_value).checked_type.dtype
+        ind_dtype = infer_type(indices).checked_type.dtype
+        # Normalize the indices to a positive range
+        indices = _op.where(
+            indices < _op.const(0, ind_dtype), indices + _op.cast(depth, ind_dtype), indices
+        )
         # set default value when axis is not set in the model
         if "axis" not in attr:
             attr["axis"] = -1
-        return _op.one_hot(indices, on_value, off_value, depth, int(attr["axis"]), dtype=dtype)
+        axis = attr["axis"]
+        if axis < 0:
+            axis += ndim + 1
+
+        return _op.one_hot(indices, on_value, off_value, depth, axis, dtype=dtype)
 
 
 class ConstantOfShape(OnnxOpConverter):
@@ -1552,7 +1617,7 @@ class Constant(OnnxOpConverter):
     @classmethod
     def _impl_v9(cls, inputs, attr, params):
         if "value" not in attr:
-            raise "No Value in Constant"
+            raise tvm.errors.OpAttributeRequired("no value in Constant")
         np_value = get_numpy(attr.pop("value"))
         dtype = np_value.dtype.name
         value = _expr.const(np_value, dtype)
@@ -2042,7 +2107,7 @@ class TopK(OnnxOpConverter):
         largest = attr.get("largest", 1)
 
         if largest == 0:
-            raise ValueError("TVM only supports finding TopK largest elements")
+            raise NotImplementedError("TVM only supports finding TopK largest elements")
 
         return _op.topk(inputs[0], inputs[1], axis=axis, dtype="int64")
 
@@ -2087,7 +2152,7 @@ class RoiAlign(OnnxOpConverter):
         batch_indices = inputs[2]
         mode = attr.get("mode", b"avg")
         if mode not in (b"avg", b"max"):
-            raise ValueError("RoiAlign in Relay only uses avg and max modes")
+            raise NotImplementedError("RoiAlign in Relay only uses avg and max modes")
         output_height = attr.get("output_height", 1)
         output_width = attr.get("output_width", 1)
 
@@ -2128,7 +2193,8 @@ class Clip(OnnxOpConverter):
         result = inputs[0]
         for i, op in enumerate([_op.tensor.maximum, _op.tensor.minimum]):
             if i < len(inputs) - 1:
-                result = op(result, inputs[i + 1])
+                if inputs[i + 1] is not None:
+                    result = op(result, inputs[i + 1])
         return result
 
 
@@ -2393,9 +2459,10 @@ class NonMaxSuppression(OnnxOpConverter):
         dtype = infer_type(boxes).checked_type.dtype
 
         if "center_point_box" in attr:
-            assert (
-                attr["center_point_box"] == 0
-            ), "Only support center_point_box = 0 in onnx importer right now"
+            if attr["center_point_box"] != 0:
+                raise NotImplementedError(
+                    "Only support center_point_box = 0 in ONNX NonMaxSuprresion"
+                )
 
         if iou_threshold is None:
             iou_threshold = _expr.const(0.0, dtype="float32")
@@ -2718,7 +2785,7 @@ def _get_convert_map(opset):
         "Softplus": Softplus.get_converter(opset),
         # softmax default axis is different in onnx
         "Softmax": Softmax.get_converter(opset),
-        "LogSoftmax": AttrCvt("log_softmax", {"axis": ("axis", 1)}),
+        "LogSoftmax": LogSoftmax.get_converter(opset),
         "OneHot": OneHot.get_converter(opset),
         # 'Hardmax'
         "Softsign": Softsign.get_converter(opset),
@@ -2958,6 +3025,8 @@ class GraphProto:
             for i in node.input:
                 if i != "":
                     inputs[i] = self._nodes[self._renames.get(i, i)]
+                else:
+                    inputs[i] = None
             i_name = self._parse_value_proto(node)
             node_output = self._fix_outputs(op_name, node.output)
             attr["tvm_custom"] = {}
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 4129b61..df0ae76 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -905,10 +905,13 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
             end = const(list(end))
         if isinstance(strides, (tuple, list)):
             strides = const(list(strides))
-        normalized_begin = _make.where(
+        begin = _make.where(
             begin < cast_like(const(0), begin), begin + cast_like(shape_of(data), begin), begin
         )
-        return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode)
+        begin = _make.where(
+            begin >= cast_like(shape_of(data), begin), cast_like(shape_of(data), begin), begin
+        )
+        return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)
     return _make.strided_slice(data, begin, end, strides, slice_mode)
 
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 5a6216a..ec89a3d 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -4090,6 +4090,169 @@ def test_cumsum():
     verify_cumsum(data, 1, 1, 1, type="int32")
 
 
+import glob
+from onnx import numpy_helper
+# The ONNX backend node tests ship inside the installed onnx package.
+_onnx_dir = "/".join(onnx.__file__.split("/")[:-1])
+onnx_test_folders = sorted(glob.glob(_onnx_dir + "/backend/test/data/node/*/"))
+
+# Node tests the importer does not support yet; matched by substring against each test path.
+unsupported_onnx_tests = [
+    "test_basic_convinteger/",
+    "test_bitshift_left_uint16/",
+    "test_bitshift_left_uint32/",
+    "test_bitshift_left_uint64/",
+    "test_bitshift_left_uint8/",
+    "test_bitshift_right_uint16/",
+    "test_bitshift_right_uint32/",
+    "test_bitshift_right_uint64/",
+    "test_bitshift_right_uint8/",
+    "test_cast_DOUBLE_to_FLOAT16/",
+    "test_cast_FLOAT16_to_DOUBLE/",
+    "test_cast_FLOAT16_to_FLOAT/",
+    "test_cast_FLOAT_to_FLOAT16/",
+    "test_cast_FLOAT_to_STRING/",
+    "test_cast_STRING_to_FLOAT/",
+    "test_compress_0/",
+    "test_compress_1/",
+    "test_compress_default_axis/",
+    "test_compress_negative_axis/",
+    "test_convinteger_with_padding/",
+    "test_convtranspose_dilations/",
+    "test_convtranspose_output_shape/",
+    "test_cumsum_1d/",
+    "test_cumsum_1d_exclusive/",
+    "test_cumsum_1d_reverse/",
+    "test_cumsum_1d_reverse_exclusive/",
+    "test_cumsum_2d_axis_0/",
+    "test_cumsum_2d_axis_1/",
+    "test_cumsum_2d_negative_axis/",
+    "test_dequantizelinear/",
+    "test_det_2d/",
+    "test_det_nd/",
+    "test_dynamicquantizelinear/",
+    "test_dynamicquantizelinear_expanded/",
+    "test_dynamicquantizelinear_max_adjusted/",
+    "test_dynamicquantizelinear_max_adjusted_expanded/",
+    "test_dynamicquantizelinear_min_adjusted/",
+    "test_dynamicquantizelinear_min_adjusted_expanded/",
+    "test_eyelike_populate_off_main_diagonal/",
+    "test_eyelike_with_dtype/",
+    "test_eyelike_without_dtype/",
+    "test_hardmax_axis_0/",
+    "test_hardmax_axis_1/",
+    "test_hardmax_axis_2/",
+    "test_hardmax_default_axis/",
+    "test_hardmax_example/",
+    "test_hardmax_negative_axis/",
+    "test_hardmax_one_hot/",
+    "test_isinf_negative/",
+    "test_isinf_positive/",
+    "test_lstm_defaults/",
+    "test_lstm_with_initial_bias/",
+    "test_lstm_with_peepholes/",
+    "test_matmulinteger/",
+    "test_maxpool_2d_dilations/",
+    "test_maxpool_2d_same_lower/",
+    "test_maxpool_2d_same_upper/",
+    "test_maxpool_with_argmax_2d_precomputed_pads/",
+    "test_maxpool_with_argmax_2d_precomputed_strides/",
+    "test_maxunpool_export_with_output_shape/",
+    "test_mvn/",
+    "test_nonmaxsuppression_center_point_box_format/",
+    "test_qlinearconv/",
+    "test_qlinearmatmul_2D/",
+    "test_qlinearmatmul_3D/",
+    "test_quantizelinear/",
+    "test_range_float_type_positive_delta_expanded/",
+    "test_range_int32_type_negative_delta_expanded/",
+    "test_resize_downsample_scales_cubic/",
+    "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside/",
+    "test_resize_downsample_scales_cubic_align_corners/",
+    "test_resize_downsample_scales_linear/",
+    "test_resize_downsample_scales_nearest/",
+    "test_resize_downsample_sizes_cubic/",
+    "test_resize_downsample_sizes_linear_pytorch_half_pixel/",
+    "test_resize_downsample_sizes_nearest/",
+    "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn/",
+    "test_resize_tf_crop_and_resize/",
+    "test_resize_upsample_scales_cubic/",
+    "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside/",
+    "test_resize_upsample_scales_cubic_align_corners/",
+    "test_resize_upsample_scales_cubic_asymmetric/",
+    "test_resize_upsample_scales_linear/",
+    "test_resize_upsample_sizes_cubic/",
+    "test_resize_upsample_sizes_nearest_ceil_half_pixel/",
+    "test_resize_upsample_sizes_nearest_floor_align_corners/",
+    "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/",
+    "test_reversesequence_batch/",
+    "test_reversesequence_time/",
+    "test_rnn_seq_length/",
+    "test_roialign/",
+    "test_round/",
+    "test_scan9_sum/",
+    "test_scan_sum/",
+    "test_scatternd/",
+    "test_selu_default/",
+    "test_shrink_hard/",
+    "test_shrink_soft/",
+    "test_simple_rnn_defaults/",
+    "test_simple_rnn_with_initial_bias/",
+    "test_slice_neg_steps/",
+    "test_slice_start_out_of_bounds/",
+    "test_strnormalizer_export_monday_casesensintive_lower/",
+    "test_strnormalizer_export_monday_casesensintive_nochangecase/",
+    "test_strnormalizer_export_monday_casesensintive_upper/",
+    "test_strnormalizer_export_monday_empty_output/",
+    "test_strnormalizer_export_monday_insensintive_upper_twodim/",
+    "test_strnormalizer_nostopwords_nochangecase/",
+    "test_tfidfvectorizer_tf_batch_onlybigrams_skip0/",
+    "test_tfidfvectorizer_tf_batch_onlybigrams_skip5/",
+    "test_tfidfvectorizer_tf_batch_uniandbigrams_skip5/",
+    "test_tfidfvectorizer_tf_only_bigrams_skip0/",
+    "test_tfidfvectorizer_tf_onlybigrams_levelempty/",
+    "test_tfidfvectorizer_tf_onlybigrams_skip5/",
+    "test_tfidfvectorizer_tf_uniandbigrams_skip5/",
+    "test_top_k_smallest/",
+    "test_unique_not_sorted_without_axis/",
+    "test_unique_sorted_with_axis/",
+    "test_unique_sorted_with_axis_3d/",
+    "test_unique_sorted_with_negative_axis/",
+    "test_unique_sorted_without_axis/",
+    "test_unsqueeze_unsorted_axes/",
+    "test_upsample_nearest/",
+]
+
+
+@pytest.mark.parametrize("test", onnx_test_folders)
+def test_onnx_nodes(test):
+    """Import one ONNX backend node test model and check TVM against its reference data sets."""
+    if any(failure in test for failure in unsupported_onnx_tests):
+        pytest.skip("ONNX node test not supported by the TVM importer yet")
+    onnx_model = onnx.load(test + "/model.onnx")
+    for dataset in sorted(glob.glob(test + "/*/")):
+        # Fresh lists per data set, otherwise tensors from earlier data sets leak into later runs.
+        inputs, outputs = [], []
+        tensors = sorted(glob.glob(dataset + "/*.pb"))
+        for tensor in tensors:
+            new_tensor = onnx.TensorProto()
+            with open(tensor, "rb") as f:
+                new_tensor.ParseFromString(f.read())
+            file_name = tensor.split("/")[-1]
+            if "input" in file_name:
+                inputs.append(numpy_helper.to_array(new_tensor))
+            elif "output" in file_name:
+                outputs.append(numpy_helper.to_array(new_tensor))
+            else:
+                raise ImportError(str(tensor) + " not labeled as an input or an output")
+        tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0))
+        if len(outputs) == 1:
+            tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5)
+        else:
+            for output, val in zip(outputs, tvm_val):
+                tvm.testing.assert_allclose(output, val, rtol=1e-5, atol=1e-5)
+
+
 def test_wrong_input():
     node = helper.make_node(
         "Softplus",