Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/06/19 06:51:40 UTC

[GitHub] [tvm] kevinLu1114 opened a new issue #8284: [ONNX] ssd-mobilenetv1 fail to build

kevinLu1114 opened a new issue #8284:
URL: https://github.com/apache/tvm/issues/8284


   The discussion I saw is at https://discuss.tvm.apache.org/t/failures-using-many-of-onnx-model-zoo-models/10268
   
   I used a script like https://gist.github.com/masahi/9348db919edb105912b94b84792dd7d3 to build ssd-mobilenetv1, but compilation failed with the error below.
   
   TVM version: commit 1fac10b3
   LLVM version: 12.0.1
   OS info: Ubuntu 20.10 (Groovy Gorilla)
   
   error message:
   ```
   ==> https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/ssd-mobilenetv1/model/ssd_mobilenet_v1_10.tar.gz <==
   Loading ssd_mobilenet_v1/ssd_mobilenet_v1.onnx ...
   Input shapes: {'image_tensor:0': (1, 383, 640, 3)}
   Importing graph from ONNX to TVM Relay IR ...
   /home/chlu/tvm/python/tvm/relay/frontend/onnx.py:2572: UserWarning: 
                   Using scan outputs in a loop with strided slice
                   currently may cause errors during compilation.
                   
     warnings.warn(
   [14:48:48] ../src/runtime/threading_backend.cc:217: Warning: more than two frequencies detected!
   Compiling graph from Relay IR to llvm ...
   Caught an exception Traceback (most recent call last):
     37: TVMFuncCall
     36: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::vm::VMCompiler::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
     35: tvm::relay::vm::VMCompiler::Lower(tvm::IRModule, tvm::runtime::Map<tvm::Integer, tvm::Target, void, void> const&, tvm::Target const&)
     34: tvm::relay::vm::VMCompiler::OptimizeModule(tvm::IRModule, tvm::runtime::Map<tvm::Integer, tvm::Target, void, void> const&, tvm::Target const&)
     33: tvm::transform::Pass::operator()(tvm::IRModule) const
     32: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     31: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     30: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     29: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     28: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::AlterOpLayout()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::AlterOpLayout()::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
     27: tvm::relay::alter_op_layout::AlterOpLayout(tvm::RelayExpr const&)
     26: tvm::relay::ForwardRewrite(tvm::RelayExpr const&, tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)> const&, std::function<tvm::runtime::ObjectRef (tvm::relay::Call const&)>, std::function<tvm::RelayExpr (tvm::RelayExpr const&)>)
     25: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
     24: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
     23: _ZN3tvm5relay16MixedModeMutator17DispatchVisitExprERKNS_9RelayExp
     22: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     21: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
     20: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
     19: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
     18: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
     17: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
     16: _ZN3tvm5relay16MixedModeMutator17DispatchVisitExprERKNS_9RelayExp
     15: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     14: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
     13: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
     12: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::LetNode const*)
     11: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
     10: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
     9: _ZN3tvm5relay16MixedModeMutator17DispatchVisitExprERKNS_9RelayExp
     8: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     7: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
     6: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
     5: tvm::relay::MixedModeMutator::VisitExpr_(tvm::relay::CallNode const*)
     4: tvm::relay::ForwardRewriter::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&)
     3: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)>(tvm::RelayExpr (*)(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
     2: tvm::RelayExpr tvm::relay::LayoutRewriter<tvm::relay::alter_op_layout::AlterTransformMemorizer>(tvm::relay::Call const&, tvm::runtime::Array<tvm::RelayExpr, void> const&, tvm::runtime::ObjectRef const&)
     1: tvm::relay::alter_op_layout::AlterTransformMemorizer::CallWithNewLayouts(tvm::relay::Call const&, std::vector<tvm::RelayExpr, std::allocator<tvm::RelayExpr> > const&)
     0: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) [clone .cold]
     File "/home/chlu/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
       rv = local_pyfunc(*pyargs)
     File "/home/chlu/tvm/python/tvm/relay/op/nn/_nn.py", line 195, in alter_op_layout_conv2d
       return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type)
     File "<decorator-gen-58>", line 2, in conv2d_alter_layout
     File "/home/chlu/tvm/python/tvm/target/generic_func.py", line 275, in dispatch_func
       return dispatch_dict[k](*args, **kwargs)
     File "/home/chlu/tvm/python/tvm/topi/x86/conv2d_alter_op.py", line 60, in _alter_conv2d_layout
       impl, outs = relay.backend.compile_engine.select_implementation(
     File "/home/chlu/tvm/python/tvm/relay/backend/compile_engine.py", line 219, in select_implementation
       outs = impl.compute(attrs, inputs, out_type)
     File "/home/chlu/tvm/python/tvm/relay/op/op.py", line 125, in compute
       return _OpImplementationCompute(self, attrs, inputs, out_type)
     File "/home/chlu/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
       raise get_last_ffi_error()
     3: TVMFuncCall
     2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::__mk_TVM6::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
     1: tvm::relay::OpImplementation::Compute(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)
     0: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) [clone .cold]
     File "/home/chlu/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
       rv = local_pyfunc(*pyargs)
     File "/home/chlu/tvm/python/tvm/relay/op/strategy/generic.py", line 240, in _compute_conv2d
       return [topi_compute(*args)]
     File "/home/chlu/tvm/python/tvm/topi/x86/conv2d.py", line 129, in conv2d_nchw
       packed_out = conv2d_NCHWc(data, kernel, strides, padding, dilation, layout, layout, out_dtype)
     File "/home/chlu/tvm/python/tvm/autotvm/task/topi_integration.py", line 165, in wrapper
       node = topi_compute(cfg, *args)
     File "/home/chlu/tvm/python/tvm/topi/x86/conv2d.py", line 191, in conv2d_NCHWc
       oh = (ih - kernel_height + pt + pb) // sh + 1
   TypeError: unsupported operand type(s) for -: 'Any' and 'int'
   ```
   
   script
   ```python
   # Licensed to the Apache Software Foundation (ASF) under one
   # or more contributor license agreements.  See the NOTICE file
   # distributed with this work for additional information
   # regarding copyright ownership.  The ASF licenses this file
   # to you under the Apache License, Version 2.0 (the
   # "License"); you may not use this file except in compliance
   # with the License.  You may obtain a copy of the License at
   #
   #   http://www.apache.org/licenses/LICENSE-2.0
   #
   # Unless required by applicable law or agreed to in writing,
   # software distributed under the License is distributed on an
   # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
   # KIND, either express or implied.  See the License for the
   # specific language governing permissions and limitations
   # under the License.
   
   # See:
   # - https://tvm.apache.org/docs/tutorials/frontend/from_onnx.html
   # - https://github.com/apache/tvm/blob/main/tutorials/frontend/from_onnx.py
   # - https://github.com/onnx/models
   
   
   import subprocess
   import os
   import sys
   import posixpath
   from urllib.request import urlretrieve
   import glob
   
   import onnx
   from onnx import numpy_helper
   import numpy as np
   import tvm
   import tvm.relay as relay
   from tvm.contrib import graph_executor
   from tvm.runtime.vm import VirtualMachine
   
   
   def get_value_info_shape(value_info):
       return tuple([max(d.dim_value, 1) for d in value_info.type.tensor_type.shape.dim])
   
   urls = [
       'https://github.com/onnx/models/raw/master/vision/object_detection_segmentation/ssd-mobilenetv1/model/ssd_mobilenet_v1_10.tar.gz',
   ]
   
   # Note: the error log above was produced with target = "llvm".
   target = "cuda"
   
   ctx = tvm.device(target, 0)
   
   summary = []
   for url in urls:
       print(f'==> {url} <==')
   
       archive = posixpath.basename(url)
       if not os.path.exists(archive):
           print(f'Downloading {url} ...')
           urlretrieve(url, archive)
           assert os.path.exists(archive)
   
       import tarfile
       tar = tarfile.open(archive, 'r:gz')
       for n in tar.getnames():
           if n.endswith('.onnx'):
               model_file = n
               name = os.path.dirname(n)
               break
       else:
           raise RuntimeError(f'No .onnx file found in {archive}')
   
       if not os.path.exists(model_file):
           print(f'Extracting {archive} ...')
           #subprocess.call(['tar', 'xzf', archive])
           tar.extractall()
           assert os.path.exists(model_file)
   
       print(f'Loading {model_file} ...')
       onnx_model = onnx.load(model_file)
   
       graph = onnx_model.graph
   
       initializers = set()
       for initializer in graph.initializer:
           initializers.add(initializer.name)
   
       input_values = []
   
       test_data_set = glob.glob(os.path.join(name, 'test_data_set_*'))[0]
       shape_dict = {}
       assert os.path.exists(test_data_set)
       inputs = {}
       for input in graph.input:
           if input.name not in initializers:
               i = len(input_values)
               input_data = os.path.join(test_data_set, f'input_{i}.pb')
               tensor = onnx.TensorProto()
               with open(input_data, 'rb') as f:
                   tensor.ParseFromString(f.read())
               x = numpy_helper.to_array(tensor)
               input_values.append(x)
               shape_dict[input.name] = x.shape
               inputs[input.name] = tvm.nd.array(x, ctx)
   
       print(f'Input shapes: {shape_dict}')
   
       try:
           print(f'Importing graph from ONNX to TVM Relay IR ...')
           mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)
           mod = relay.transform.DynamicToStatic()(mod)
   
           print(f'Compiling graph from Relay IR to {target} ...')
           with tvm.transform.PassContext(opt_level=3):
               vm_exec = relay.vm.compile(mod, target, params=params)
   
           dev = tvm.device(target, 0)
           vm = VirtualMachine(vm_exec, dev)
           vm.set_input("main", **inputs)
   
           print(f"Running inference...")
           vm.run()
       except KeyboardInterrupt:
           raise
       except Exception as ex:
           print(f'Caught an exception {ex}')
           result = 'not ok'
       else:
           print(f'Succeeded!')
           result = 'ok'
       summary.append((result, url))
       print()
   
   print('Summary:')
   for result, url in summary:
       print(f'{result}\t- {url}')
   ```
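The script defines `get_value_info_shape`, which is not used in this run. It converts an ONNX value-info shape into a static tuple. In the ONNX protobuf, a symbolic or unknown dimension has no `dim_value` set, so it reads back as 0, and `max(d.dim_value, 1)` replaces it with 1. Below is a stdlib-only sketch of the same logic. The `Dim` tuple is a hypothetical stand-in for `onnx.TensorShapeProto.Dimension`, and the example shape is an assumption.

```python
from collections import namedtuple

# Hypothetical stand-in for onnx.TensorShapeProto.Dimension.
# An unset dim_value (symbolic/unknown dimension) reads back as 0.
Dim = namedtuple("Dim", ["dim_value", "dim_param"])


def static_shape(dims):
    """Same logic as get_value_info_shape: unknown (0) dims become 1."""
    return tuple(max(d.dim_value, 1) for d in dims)


# Hypothetical NHWC input whose batch dimension is symbolic ("N").
dims = [Dim(0, "N"), Dim(383, ""), Dim(640, ""), Dim(3, "")]
print(static_shape(dims))  # (1, 383, 640, 3)
```

This only makes the graph *inputs* static. As discussed later in the thread, the dynamic dimensions in this model appear after the preprocessing `Loop`.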
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] kevinLu1114 edited a comment on issue #8284: [ONNX] ssd-mobilenetv1 fail to build

Posted by GitBox <gi...@apache.org>.
kevinLu1114 edited a comment on issue #8284:
URL: https://github.com/apache/tvm/issues/8284#issuecomment-864496636


   @masahi Thank you for your reply. I have also observed this, but I am not sure whether it is because the ssd-mobilenet model inherently needs dynamic shapes (e.g. for the predicted bounding boxes).
   
   If it helps, I'll contribute as much as I can, though my ability is probably limited since I'm a beginner :)
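Here is a minimal sketch of why detection *outputs* are naturally dynamic. The number of boxes that survive score filtering depends on the data values, not only on the input shape. By contrast, a conv2d's output shape depends only on its input shape. `filter_detections` and the 0.5 threshold are hypothetical and are not taken from the ssd-mobilenetv1 model or from TVM.

```python
def filter_detections(boxes, scores, threshold=0.5):
    """Hypothetical post-processing: keep boxes whose score passes the threshold."""
    return [box for box, score in zip(boxes, scores) if score >= threshold]


boxes = [(0, 0, 10, 10), (5, 5, 20, 20), (1, 1, 4, 4)]

# Same input shape (3 boxes), different output lengths depending on the scores.
print(len(filter_detections(boxes, [0.9, 0.2, 0.7])))  # 2
print(len(filter_detections(boxes, [0.1, 0.2, 0.3])))  # 0
```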
   






[GitHub] [tvm] masahi commented on issue #8284: [ONNX] ssd-mobilenetv1 fail to build

Posted by GitBox <gi...@apache.org>.
masahi commented on issue #8284:
URL: https://github.com/apache/tvm/issues/8284#issuecomment-864467897


   This is a very strange model in that it contains multiple ONNX `Loop` ops for no good reason. In particular, a loop at the beginning does input image preprocessing, and for some reason the output of that loop is already dynamic in all dimensions. So the input to the first convolution op is already dynamic in the H and W dimensions, which results in the error above.
   
   ```
   ...
   %37 = subtract(%36, meta[relay.Constant][5] /* ty=Tensor[(1, 1, 1, 1), float32] */) /* ty=Tensor[(?, ?, ?, ?), float32] */;
   %38 = nn.conv2d(%37, meta[relay.Constant][6] /* ty=Tensor[(32, 3, 3, 3), float32] */, strides=[2, 2], padding=[0, 0, 1, 1], kernel_size=[3, 3]) /* ty=Tensor[(?, 32, ?, ?), float32] */;
   ...
   ``` 
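The sketch below mimics the failing line in `topi/x86/conv2d.py` (from the traceback). It computes the output height using the `kernel_size=[3, 3]`, `strides=[2, 2]` and `padding=[0, 0, 1, 1]` (top, left, bottom, right) of the conv2d above. `DynamicDim` is a hypothetical stand-in for the unknown dimension. Like the `Any` in the traceback, it does not support `-` with an `int`, so the subtraction raises the same kind of `TypeError`. The static height of 300 is also only an assumed example.

```python
class DynamicDim:
    """Hypothetical stand-in for a dynamic ("?") dimension; defines no arithmetic."""

    def __repr__(self):
        return "?"


def conv2d_out_dim(in_dim, kernel, pad_before, pad_after, stride):
    # Same shape formula as the failing line in topi/x86/conv2d.py:
    #   oh = (ih - kernel_height + pt + pb) // sh + 1
    return (in_dim - kernel + pad_before + pad_after) // stride + 1


# Static (assumed) input height: the arithmetic works.
print(conv2d_out_dim(300, 3, 0, 1, 2))  # 150

# Dynamic input height: DynamicDim - int is not implemented -> TypeError.
try:
    conv2d_out_dim(DynamicDim(), 3, 0, 1, 2)
except TypeError as err:
    print(err)  # unsupported operand type(s) for -: 'DynamicDim' and 'int'
```

If H and W stay static after preprocessing, as argued here, the same formula succeeds.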
   
   I have a feeling that our ONNX `Loop` support does not preserve static shape information precisely, since it does not make sense to have a dynamic input to the first conv2d op after the preprocessing loop. This could also be one of the reasons importing MaskRCNN from ONNX does not work well: it has a loop, and compilation fails on dynamic H and W dimensions that should not exist. @jwfromm @mbrookhart 
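Below is a toy model of the suspected problem. It is not TVM's actual `Loop` conversion. A precise importer could keep every dimension that all iterations agree on. A conservative one marks every dimension dynamic, which matches the `(?, ?, ?, ?)` type in the snippet above. `join_shapes`, `loop_output_shape` and the per-iteration shapes are all hypothetical.

```python
from typing import List, Optional, Sequence, Tuple

# None stands for a dynamic ("?") dimension.
Shape = Tuple[Optional[int], ...]


def join_shapes(a: Sequence[Optional[int]], b: Sequence[Optional[int]]) -> Shape:
    """Most precise shape consistent with both inputs: keep a dim only if both agree."""
    if len(a) != len(b):
        raise ValueError("rank mismatch")
    return tuple(x if x == y else None for x, y in zip(a, b))


def loop_output_shape(per_iteration: List[Shape], precise: bool = True) -> Shape:
    """Shape assigned to a loop output, given the shape each iteration produces."""
    if not precise:
        # Conservative import: every dimension becomes dynamic.
        return tuple(None for _ in per_iteration[0])
    result = tuple(per_iteration[0])
    for shape in per_iteration[1:]:
        result = join_shapes(result, shape)
    return result


iterations = [(1, 3, 300, 300), (1, 3, 300, 300)]  # hypothetical per-iteration shapes
print(loop_output_shape(iterations))                 # (1, 3, 300, 300)
print(loop_output_shape(iterations, precise=False))  # (None, None, None, None)
```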





[GitHub] [tvm] kevinLu1114 commented on issue #8284: [ONNX] ssd-mobilenetv1 fail to build

Posted by GitBox <gi...@apache.org>.
kevinLu1114 commented on issue #8284:
URL: https://github.com/apache/tvm/issues/8284#issuecomment-864405369


   ssd-mobilenetv1 model: https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/ssd-mobilenetv1

