You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/22 21:08:26 UTC

[GitHub] [tvm] electriclilies commented on a change in pull request #7720: [ONNX] Onnx node tests

electriclilies commented on a change in pull request #7720:
URL: https://github.com/apache/tvm/pull/7720#discussion_r599033812



##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -444,9 +448,15 @@ class ConvTranspose(OnnxOpConverter):
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         # get number of channels
-        channels = infer_channels(inputs[1], True)
+        out_type = infer_type(inputs[1])
+        out_shapes = [get_const_tuple(out_type.checked_type.shape)]
+        channels = out_shapes[0][1]

Review comment:
       Does this need to work for layouts other than NCHW? It looks like the ONNX ConvTranspose operator doesn't specify a layout.

##########
File path: tests/python/frontend/onnx/test_forward.py
##########
@@ -4090,6 +4090,170 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
     verify_cumsum(data, 1, 1, 1, type="int32")
 
 
+from onnx import numpy_helper
+
+f = onnx.__file__
+import glob
+
+onnx_test_folders = sorted(glob.glob("/".join(f.split("/")[0:-1]) + "/backend/test/data/node/*/"))
+
+unsupported_onnx_tests = [
+    "test_basic_convinteger/",
+    "test_bitshift_left_uint16/",
+    "test_bitshift_left_uint32/",
+    "test_bitshift_left_uint64/",
+    "test_bitshift_left_uint8/",
+    "test_bitshift_right_uint16/",
+    "test_bitshift_right_uint32/",
+    "test_bitshift_right_uint64/",
+    "test_bitshift_right_uint8/",
+    "test_cast_DOUBLE_to_FLOAT16/",
+    "test_cast_FLOAT16_to_DOUBLE/",
+    "test_cast_FLOAT16_to_FLOAT/",
+    "test_cast_FLOAT_to_FLOAT16/",
+    "test_cast_FLOAT_to_STRING/",
+    "test_cast_STRING_to_FLOAT/",
+    "test_compress_0/",
+    "test_compress_1/",
+    "test_compress_default_axis/",
+    "test_compress_negative_axis/",
+    "test_convinteger_with_padding/",
+    "test_convtranspose_dilations/",
+    "test_convtranspose_output_shape/",
+    "test_cumsum_1d/",
+    "test_cumsum_1d_exclusive/",
+    "test_cumsum_1d_reverse/",
+    "test_cumsum_1d_reverse_exclusive/",
+    "test_cumsum_2d_axis_0/",
+    "test_cumsum_2d_axis_1/",
+    "test_cumsum_2d_negative_axis/",
+    "test_dequantizelinear/",
+    "test_det_2d/",
+    "test_det_nd/",
+    "test_dynamicquantizelinear/",
+    "test_dynamicquantizelinear_expanded/",
+    "test_dynamicquantizelinear_max_adjusted/",
+    "test_dynamicquantizelinear_max_adjusted_expanded/",
+    "test_dynamicquantizelinear_min_adjusted/",
+    "test_dynamicquantizelinear_min_adjusted_expanded/",
+    "test_eyelike_populate_off_main_diagonal/",
+    "test_eyelike_with_dtype/",
+    "test_eyelike_without_dtype/",
+    "test_hardmax_axis_0/",
+    "test_hardmax_axis_1/",
+    "test_hardmax_axis_2/",
+    "test_hardmax_default_axis/",
+    "test_hardmax_example/",
+    "test_hardmax_negative_axis/",
+    "test_hardmax_one_hot/",
+    "test_isinf_negative/",
+    "test_isinf_positive/",
+    "test_lstm_defaults/",
+    "test_lstm_with_initial_bias/",
+    "test_lstm_with_peepholes/",
+    "test_matmulinteger/",
+    "test_maxpool_2d_dilations/",
+    "test_maxpool_2d_same_lower/",
+    "test_maxpool_2d_same_upper/",
+    "test_maxpool_with_argmax_2d_precomputed_pads/",
+    "test_maxpool_with_argmax_2d_precomputed_strides/",
+    "test_maxunpool_export_with_output_shape/",
+    "test_mvn/",
+    "test_nonmaxsuppression_center_point_box_format/",
+    "test_qlinearconv/",
+    "test_qlinearmatmul_2D/",
+    "test_qlinearmatmul_3D/",
+    "test_quantizelinear/",
+    "test_range_float_type_positive_delta_expanded/",
+    "test_range_int32_type_negative_delta_expanded/",
+    "test_resize_downsample_scales_cubic/",
+    "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside/",
+    "test_resize_downsample_scales_cubic_align_corners/",
+    "test_resize_downsample_scales_linear/",
+    "test_resize_downsample_scales_nearest/",
+    "test_resize_downsample_sizes_cubic/",
+    "test_resize_downsample_sizes_linear_pytorch_half_pixel/",
+    "test_resize_downsample_sizes_nearest/",
+    "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn/",
+    "test_resize_tf_crop_and_resize/",
+    "test_resize_upsample_scales_cubic/",
+    "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside/",
+    "test_resize_upsample_scales_cubic_align_corners/",
+    "test_resize_upsample_scales_cubic_asymmetric/",
+    "test_resize_upsample_scales_linear/",
+    "test_resize_upsample_sizes_cubic/",
+    "test_resize_upsample_sizes_nearest_ceil_half_pixel/",
+    "test_resize_upsample_sizes_nearest_floor_align_corners/",
+    "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/",
+    "test_reversesequence_batch/",
+    "test_reversesequence_time/",
+    "test_rnn_seq_length/",
+    "test_roialign/",
+    "test_round/",
+    "test_scan9_sum/",
+    "test_scan_sum/",
+    "test_scatternd/",
+    "test_selu_default/",
+    "test_shrink_hard/",
+    "test_shrink_soft/",
+    "test_simple_rnn_defaults/",
+    "test_simple_rnn_with_initial_bias/",
+    "test_slice_neg_steps/",
+    "test_slice_start_out_of_bounds/",
+    "test_strnormalizer_export_monday_casesensintive_lower/",
+    "test_strnormalizer_export_monday_casesensintive_nochangecase/",
+    "test_strnormalizer_export_monday_casesensintive_upper/",
+    "test_strnormalizer_export_monday_empty_output/",
+    "test_strnormalizer_export_monday_insensintive_upper_twodim/",
+    "test_strnormalizer_nostopwords_nochangecase/",
+    "test_tfidfvectorizer_tf_batch_onlybigrams_skip0/",
+    "test_tfidfvectorizer_tf_batch_onlybigrams_skip5/",
+    "test_tfidfvectorizer_tf_batch_uniandbigrams_skip5/",
+    "test_tfidfvectorizer_tf_only_bigrams_skip0/",
+    "test_tfidfvectorizer_tf_onlybigrams_levelempty/",
+    "test_tfidfvectorizer_tf_onlybigrams_skip5/",
+    "test_tfidfvectorizer_tf_uniandbigrams_skip5/",
+    "test_top_k_smallest/",
+    "test_unique_not_sorted_without_axis/",
+    "test_unique_sorted_with_axis/",
+    "test_unique_sorted_with_axis_3d/",
+    "test_unique_sorted_with_negative_axis/",
+    "test_unique_sorted_without_axis/",
+    "test_unsqueeze_unsorted_axes/",
+    "test_upsample_nearest/",
+]
+
+
+@pytest.mark.parametrize("test", onnx_test_folders)
+def test_onnx_nodes(test):
+    for failure in unsupported_onnx_tests:
+        if failure in test:
+            pytest.skip()
+            break
+    onnx_model = onnx.load(test + "/model.onnx")
+    inputs = []
+    outputs = []
+    for dataset in glob.glob(test + "/*/"):
+        tensors = sorted(glob.glob(dataset + "/*.pb"))
+        for tensor in tensors:
+            new_tensor = onnx.TensorProto()
+            with open(tensor, "rb") as f:
+                new_tensor.ParseFromString(f.read())
+            if "input" in tensor.split("/")[-1]:
+                inputs.append(numpy_helper.to_array(new_tensor))
+            elif "output" in tensor.split("/")[-1]:
+                outputs.append(numpy_helper.to_array(new_tensor))
+            else:
+                print(tensor)
+                raise

Review comment:
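       Can you put an error message here, maybe something like "Expected tensor to be either an input or output"? A bare `raise` outside an `except` block only produces `RuntimeError: No active exception to reraise`, which doesn't say what went wrong.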




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org