Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/22 18:23:40 UTC

[GitHub] [tvm] mbrookhart opened a new pull request #7720: [ONNX] Onnx node tests

mbrookhart opened a new pull request #7720:
URL: https://github.com/apache/tvm/pull/7720


   We've been hitting a lot of errors when running ONNX models, which has led to a lot of piecemeal fixes. This is an attempt to fix the importer more broadly by running the node tests that ONNX ships in its pip package: https://github.com/onnx/onnx/tree/master/onnx/backend/test/data/node
   
   Each test directory contains an ONNX graph, input arrays, and expected outputs, so we can test directly against the canonical ONNX tests. This PR provides a method to import these tests as parameterized unit tests, execute them, and skip any we know currently fail. I also fixed a lot of low-hanging fruit to reduce the number of skipped unit tests.
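   As a rough illustration of the discover-and-skip approach described above, here is a minimal standard-library sketch. It is not the code from this PR: the helper names `discover_node_tests` and `should_skip` are hypothetical, the skip set holds only two entries copied from the PR's much longer list, and it matches directory names exactly rather than by substring as the PR does.

```python
from pathlib import Path

# Two entries copied from the PR's skip list; the real list is far longer.
KNOWN_FAILURES = {"test_det_2d", "test_round"}


def discover_node_tests(root):
    """Return sorted test-case directories under `root` that hold a model.onnx.

    `root` is assumed to be something like
    Path(onnx.__file__).parent / "backend/test/data/node".
    """
    return sorted(
        p for p in Path(root).iterdir()
        if p.is_dir() and (p / "model.onnx").is_file()
    )


def should_skip(test_dir, known_failures=KNOWN_FAILURES):
    """True when the directory name is listed as a known failure."""
    return Path(test_dir).name in known_failures
```

   The result of `discover_node_tests` could be passed to `pytest.mark.parametrize`, with `should_skip` deciding whether to call `pytest.skip` inside the test body.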
   
   Future PRs will work to fix the currently skipped tests, and then extend this to GPU.
   
   For reference, this is the pytest result on my system, testing against ONNX 1.6, which is what we have in CI:
   
   `434 passed, 123 skipped, 83 deselected, 1185 warnings in 32.40s`
   
   This adds a lot of tests, but they are all small, so the runtime is actually pretty minuscule, and it improves our ONNX import coverage dramatically.
   
   cc @jwfromm @masahi @jroesch @electriclilies @adelbertc 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] electriclilies commented on a change in pull request #7720: [ONNX] Onnx node tests

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #7720:
URL: https://github.com/apache/tvm/pull/7720#discussion_r599033812



##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -444,9 +448,15 @@ class ConvTranspose(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         # get number of channels
-        channels = infer_channels(inputs[1], True)
+        out_type = infer_type(inputs[1])
+        out_shapes = [get_const_tuple(out_type.checked_type.shape)]
+        channels = out_shapes[0][1]

Review comment:
       Does this need to work for layouts other than NCHW? It looks like the ONNX ConvTranspose operator doesn't specify a layout.

##########
File path: tests/python/frontend/onnx/test_forward.py
##########
@@ -4090,6 +4090,170 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
     verify_cumsum(data, 1, 1, 1, type="int32")
 
 
+from onnx import numpy_helper
+
+f = onnx.__file__
+import glob
+
+onnx_test_folders = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/"))
+
+unsupported_onnx_tests = [
+    "test_basic_convinteger/",
+    "test_bitshift_left_uint16/",
+    "test_bitshift_left_uint32/",
+    "test_bitshift_left_uint64/",
+    "test_bitshift_left_uint8/",
+    "test_bitshift_right_uint16/",
+    "test_bitshift_right_uint32/",
+    "test_bitshift_right_uint64/",
+    "test_bitshift_right_uint8/",
+    "test_cast_DOUBLE_to_FLOAT16/",
+    "test_cast_FLOAT16_to_DOUBLE/",
+    "test_cast_FLOAT16_to_FLOAT/",
+    "test_cast_FLOAT_to_FLOAT16/",
+    "test_cast_FLOAT_to_STRING/",
+    "test_cast_STRING_to_FLOAT/",
+    "test_compress_0/",
+    "test_compress_1/",
+    "test_compress_default_axis/",
+    "test_compress_negative_axis/",
+    "test_convinteger_with_padding/",
+    "test_convtranspose_dilations/",
+    "test_convtranspose_output_shape/",
+    "test_cumsum_1d/",
+    "test_cumsum_1d_exclusive/",
+    "test_cumsum_1d_reverse/",
+    "test_cumsum_1d_reverse_exclusive/",
+    "test_cumsum_2d_axis_0/",
+    "test_cumsum_2d_axis_1/",
+    "test_cumsum_2d_negative_axis/",
+    "test_dequantizelinear/",
+    "test_det_2d/",
+    "test_det_nd/",
+    "test_dynamicquantizelinear/",
+    "test_dynamicquantizelinear_expanded/",
+    "test_dynamicquantizelinear_max_adjusted/",
+    "test_dynamicquantizelinear_max_adjusted_expanded/",
+    "test_dynamicquantizelinear_min_adjusted/",
+    "test_dynamicquantizelinear_min_adjusted_expanded/",
+    "test_eyelike_populate_off_main_diagonal/",
+    "test_eyelike_with_dtype/",
+    "test_eyelike_without_dtype/",
+    "test_hardmax_axis_0/",
+    "test_hardmax_axis_1/",
+    "test_hardmax_axis_2/",
+    "test_hardmax_default_axis/",
+    "test_hardmax_example/",
+    "test_hardmax_negative_axis/",
+    "test_hardmax_one_hot/",
+    "test_isinf_negative/",
+    "test_isinf_positive/",
+    "test_lstm_defaults/",
+    "test_lstm_with_initial_bias/",
+    "test_lstm_with_peepholes/",
+    "test_matmulinteger/",
+    "test_maxpool_2d_dilations/",
+    "test_maxpool_2d_same_lower/",
+    "test_maxpool_2d_same_upper/",
+    "test_maxpool_with_argmax_2d_precomputed_pads/",
+    "test_maxpool_with_argmax_2d_precomputed_strides/",
+    "test_maxunpool_export_with_output_shape/",
+    "test_mvn/",
+    "test_nonmaxsuppression_center_point_box_format/",
+    "test_qlinearconv/",
+    "test_qlinearmatmul_2D/",
+    "test_qlinearmatmul_3D/",
+    "test_quantizelinear/",
+    "test_range_float_type_positive_delta_expanded/",
+    "test_range_int32_type_negative_delta_expanded/",
+    "test_resize_downsample_scales_cubic/",
+    "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside/",
+    "test_resize_downsample_scales_cubic_align_corners/",
+    "test_resize_downsample_scales_linear/",
+    "test_resize_downsample_scales_nearest/",
+    "test_resize_downsample_sizes_cubic/",
+    "test_resize_downsample_sizes_linear_pytorch_half_pixel/",
+    "test_resize_downsample_sizes_nearest/",
+    "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn/",
+    "test_resize_tf_crop_and_resize/",
+    "test_resize_upsample_scales_cubic/",
+    "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside/",
+    "test_resize_upsample_scales_cubic_align_corners/",
+    "test_resize_upsample_scales_cubic_asymmetric/",
+    "test_resize_upsample_scales_linear/",
+    "test_resize_upsample_sizes_cubic/",
+    "test_resize_upsample_sizes_nearest_ceil_half_pixel/",
+    "test_resize_upsample_sizes_nearest_floor_align_corners/",
+    "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/",
+    "test_reversesequence_batch/",
+    "test_reversesequence_time/",
+    "test_rnn_seq_length/",
+    "test_roialign/",
+    "test_round/",
+    "test_scan9_sum/",
+    "test_scan_sum/",
+    "test_scatternd/",
+    "test_selu_default/",
+    "test_shrink_hard/",
+    "test_shrink_soft/",
+    "test_simple_rnn_defaults/",
+    "test_simple_rnn_with_initial_bias/",
+    "test_slice_neg_steps/",
+    "test_slice_start_out_of_bounds/",
+    "test_strnormalizer_export_monday_casesensintive_lower/",
+    "test_strnormalizer_export_monday_casesensintive_nochangecase/",
+    "test_strnormalizer_export_monday_casesensintive_upper/",
+    "test_strnormalizer_export_monday_empty_output/",
+    "test_strnormalizer_export_monday_insensintive_upper_twodim/",
+    "test_strnormalizer_nostopwords_nochangecase/",
+    "test_tfidfvectorizer_tf_batch_onlybigrams_skip0/",
+    "test_tfidfvectorizer_tf_batch_onlybigrams_skip5/",
+    "test_tfidfvectorizer_tf_batch_uniandbigrams_skip5/",
+    "test_tfidfvectorizer_tf_only_bigrams_skip0/",
+    "test_tfidfvectorizer_tf_onlybigrams_levelempty/",
+    "test_tfidfvectorizer_tf_onlybigrams_skip5/",
+    "test_tfidfvectorizer_tf_uniandbigrams_skip5/",
+    "test_top_k_smallest/",
+    "test_unique_not_sorted_without_axis/",
+    "test_unique_sorted_with_axis/",
+    "test_unique_sorted_with_axis_3d/",
+    "test_unique_sorted_with_negative_axis/",
+    "test_unique_sorted_without_axis/",
+    "test_unsqueeze_unsorted_axes/",
+    "test_upsample_nearest/",
+]
+
+
+@pytest.mark.parametrize("test", onnx_test_folders)
+def test_onnx_nodes(test):
+    for failure in unsupported_onnx_tests:
+        if failure in test:
+            pytest.skip()
+            break
+    onnx_model = onnx.load(test + "/model.onnx")
+    inputs = []
+    outputs = []
+    for dataset in glob.glob(test + "/*/"):
+        tensors = sorted(glob.glob(dataset + "/*.pb"))
+        for tensor in tensors:
+            new_tensor = onnx.TensorProto()
+            with open(tensor, "rb") as f:
+                new_tensor.ParseFromString(f.read())
+            if "input" in tensor.split("/")[-1]:
+                inputs.append(numpy_helper.to_array(new_tensor))
+            elif "output" in tensor.split("/")[-1]:
+                outputs.append(numpy_helper.to_array(new_tensor))
+            else:
+                print(tensor)
+                raise

Review comment:
       Can you put an error message here? Maybe something like "Expected tensor to be either an input or output"
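       A minimal sketch of what that could look like, using only the standard library. The helper name `classify_tensor_file` is hypothetical and the message wording is just the suggestion above. Note that a bare `raise` outside an `except` block fails with "RuntimeError: No active exception to reraise", so an explicit exception type is also clearer.

```python
from pathlib import Path


def classify_tensor_file(path):
    """Label a serialized tensor file as "input" or "output" by its file name."""
    name = Path(path).name
    # Same order of checks as the diff above: "input" first, then "output".
    if "input" in name:
        return "input"
    if "output" in name:
        return "output"
    raise ValueError(
        f"Expected tensor to be either an input or output, got: {path}"
    )
```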







[GitHub] [tvm] mbrookhart commented on a change in pull request #7720: [ONNX] Onnx node tests

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #7720:
URL: https://github.com/apache/tvm/pull/7720#discussion_r599079051



##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -444,9 +448,15 @@ class ConvTranspose(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         # get number of channels
-        channels = infer_channels(inputs[1], True)
+        out_type = infer_type(inputs[1])
+        out_shapes = [get_const_tuple(out_type.checked_type.shape)]
+        channels = out_shapes[0][1]

Review comment:
       ONNX always assumes NCHW
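       To make the indexing concrete: the ONNX operator spec gives the ConvTranspose weight the shape (C, M/group, kH, kW), where C is the input channel count and M the output channel count, so axis 1 of `inputs[1]` holds M/group. The sketch below is illustrative only. `conv_transpose_out_channels` is a hypothetical helper, not a TVM or ONNX function, and it works on a plain shape tuple.

```python
def conv_transpose_out_channels(weight_shape, group=1):
    """Recover the output channel count M from an ONNX ConvTranspose weight shape.

    Assumes the ONNX weight layout (C, M/group, k1, k2, ...), so axis 1 holds
    the per-group output channels.
    """
    if len(weight_shape) < 3:
        raise ValueError(
            f"Expected a weight shape with at least 3 dims, got {weight_shape}"
        )
    return weight_shape[1] * group
```

       With `group=1`, this returns the same value as reading index 1 of the weight shape, which is what the diff above does.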







[GitHub] [tvm] electriclilies commented on a change in pull request #7720: [ONNX] Onnx node tests

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #7720:
URL: https://github.com/apache/tvm/pull/7720#discussion_r599082811



##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -444,9 +448,15 @@ class ConvTranspose(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         # get number of channels
-        channels = infer_channels(inputs[1], True)
+        out_type = infer_type(inputs[1])
+        out_shapes = [get_const_tuple(out_type.checked_type.shape)]
+        channels = out_shapes[0][1]

Review comment:
       Cool, just wanted to make sure we didn't have to worry about it!







[GitHub] [tvm] jroesch merged pull request #7720: [ONNX] Onnx node tests

Posted by GitBox <gi...@apache.org>.
jroesch merged pull request #7720:
URL: https://github.com/apache/tvm/pull/7720


   

