Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/01/15 00:39:37 UTC

[GitHub] [tvm] ayeganov opened a new issue #9939: [Bug] Internal Error With Target == "Metal"

ayeganov opened a new issue #9939:
URL: https://github.com/apache/tvm/issues/9939


   I was following the tutorial for converting the ResNet-50 network using `tvmc`, but hit a hard stop with the following message:
   
   ```
   libc++abi: terminating with uncaught exception of type tvm::runtime::InternalError
   ```
   
   ### Expected behavior
   
   A tar file should be generated for running the network
   
   ### Actual behavior
   
   The process aborts
   
   ### Environment
   
   OS: MacOS Monterey 12.1, Darwin AY-M-D6ZQ 21.2.0 Darwin Kernel Version 21.2.0: Sun Nov 28 20:28:54 PST 2021; root:xnu-8019.61.5~1/RELEASE_X86_64 x86_64
   TVM: 0.9.dev0
   Graphics card: Intel(R) UHD Graphics 630, AMD Radeon Pro 555X
   
   ### Steps to reproduce
   
   ```
   import numpy as np
   import math

   import os
   import tvm
   from tvm import relay, auto_scheduler
   from tvm import testing
   from tvm.contrib import utils, xcode, coreml_runtime, graph_runtime

   target = "metal"
   target_host = "llvm -mtriple=arm64-apple-darwin20.5.0"


   def _get_model(shape, dtype, var_names):
       """Return a model and any parameters it may have."""
       a = relay.var(next(var_names), shape=shape, dtype=dtype)
       out = relay.op.reduce.mean(a, 0)
       params = {}
       return out, params


   def converter(shape):
       print("Shape: {}".format(shape))
       dtype = "float32"
       # b, data
       inputs = {"data": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype))}
       mod, params = _get_model(shape, dtype, iter(inputs))
       if isinstance(mod, tvm.relay.expr.Call):
           mod = tvm.IRModule.from_expr(mod)
       print('mod: ', mod)
       with tvm.transform.PassContext(opt_level=3):
           graph_module = relay.build(mod['main'], target=target, target_host=target_host, params=params)
       #with auto_scheduler.ApplyHistoryBest("my_mean_model_metal"):
       #    with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
       #        graph_module = relay.build(mod['main'], target=target, target_host=target_host, params=params)
       return graph_module


   def run(graph_module):
       ctx = tvm.metal(0)
       m = graph_runtime.graph_executor.GraphModule(graph_module["default"](ctx))
       m.run()


   if __name__ == "__main__":
       shape = (2, 1)
       gm = converter(shape)
       run(gm)
   ```
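   
   For reference, a minimal sketch (not part of the original report) of how the expected `.tar` artifact could be produced once the build succeeds; the file name is a placeholder:
   
   ```
   # Hypothetical export step, assuming relay.build succeeds: a ".tar" suffix
   # makes export_library package the compiled module with tvm.contrib.tar,
   # similar to the artifact tvmc writes.
   gm.export_library("resnet50-metal.tar")
   ```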
   
   



[GitHub] [tvm] ayeganov commented on issue #9939: [Bug] Internal Error With Target == "Metal"

ayeganov commented on issue #9939:
URL: https://github.com/apache/tvm/issues/9939#issuecomment-1014867710


   @masahi thank you so much for your answers and time. If you are on Patreon or some other open-source donations platform, drop the link here so I can donate toward your projects. That said, after looking into this for a little while I was able to achieve what I need in Python as my testing ground. Here is roughly what I ended up doing:
   
   ```
   import time

   import onnx
   import tvm
   from tvm import relay

   # model_path, input_name and final_img are defined elsewhere in my script.
   onnx_model = onnx.load(model_path)
   input_shape = (1, 4, 256, 144)  # I don't care about dynamic batches, since my case is pure inference of camera frames
   shape_dict = {input_name: input_shape}
   mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

   target = "metal"
   with tvm.transform.PassContext(opt_level=1):
       executor = relay.build_module.create_executor(
           "graph", mod, tvm.metal(1), target, params
       ).evaluate()

   while True:
       res = executor(tvm.nd.array(final_img, tvm.metal(1)))
       tvm_output = res
       time.sleep(0.016)
   ```
   
   Since VMs generally come with a performance hit, I like that I am able to avoid one in this case. What I essentially want is to repeat these steps in C++, using `tvm` as a third-party library with a minimal number of headers. As my starting point I used `apps/howto_deploy/cpp_deploy.cc`. Modifying the code to build for Metal (`howto_deploy/prepare_test_libs.py:47`) and then running on a Metal device (`apps/howto_deploy/cpp_deploy.cc:89`) yielded an error that the Metal loader is not registered. Looking in `apps/howto_deploy/tvm_runtime_pack.cc` leads to lines 79 and 80 for enabling Metal, which pull in Objective-C files, at which point the example fails to build due to Objective-C syntax in `/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/System/Library/Frameworks/Foundation.framework/Headers/NSObjCRuntime.h:523`. So my question at this point is: can Metal be done in C++ at all?
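   
   For context, a minimal sketch of the Python-side half of that hand-off, assuming the usual compile-in-Python / run-in-C++ flow (the file name is a placeholder); the exported library is what a C++ app would load with `tvm::runtime::Module::LoadFromFile`:
   
   ```
   # Assumed hand-off step: build a static-shape module in Python and export
   # it as a shared library for the C++ runtime to load; "gesture_model" is
   # a placeholder name.
   with tvm.transform.PassContext(opt_level=3):
       lib = relay.build(mod, target="metal", params=params)
   lib.export_library("gesture_model.dylib")
   ```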



[GitHub] [tvm] ayeganov commented on issue #9939: [Bug] Internal Error With Target == "Metal"

ayeganov commented on issue #9939:
URL: https://github.com/apache/tvm/issues/9939#issuecomment-1013682255


   @masahi I actually need to consume the model in C++, so I'm assuming all of the above is possible with the C++ API as well. Also, which error were you referring to: the original internal error or the symbolic shape one?




[GitHub] [tvm] masahi commented on issue #9939: [Bug] Internal Error With Target == "Metal"

masahi commented on issue #9939:
URL: https://github.com/apache/tvm/issues/9939#issuecomment-1015175892


   This presentation may be of interest to you https://www.youtube.com/watch?v=9fDR9upXZTI




[GitHub] [tvm] masahi commented on issue #9939: [Bug] Internal Error With Target == "Metal"

masahi commented on issue #9939:
URL: https://github.com/apache/tvm/issues/9939#issuecomment-1013608863


   This error means you need to use the VM compiler and runtime for this model. See https://github.com/apache/tvm/blob/main/gallery/how_to/deploy_models/deploy_object_detection_pytorch.py#L130-L139



[GitHub] [tvm] masahi closed issue #9939: [Bug] Internal Error With Target == "Metal"

masahi closed issue #9939:
URL: https://github.com/apache/tvm/issues/9939


   



[GitHub] [tvm] masahi commented on issue #9939: [Bug] Internal Error With Target == "Metal"

masahi commented on issue #9939:
URL: https://github.com/apache/tvm/issues/9939#issuecomment-1013799936


   > Is this outdated?
   
   No, this is specific to the model used in that tutorial (MaskRCNN) and not about VM per se. 
   
   
   
   > 2. Compiling with VM is going to bake it into the model and running it with C++ won't require me to use a VM
   
   You still need the VM runtime to run the model in your C++ app. The linked example https://github.com/apache/tvm/blob/e7024fb39ea27494fa5618102dae42e7e5551986/src/contrib/torch/pt_call_tvm/tvm_class.cc#L143 shows how to use the VM runtime C++ API.
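   
   For reference, a rough sketch of the Python-side serialization that typically precedes C++ VM deployment (`vm_exec` and the file names are assumptions):
   
   ```
   # Assumed flow: serialize the executable produced by relay.vm.compile so a
   # C++ app can load it with the VM runtime; file names are placeholders.
   code, lib = vm_exec.save()
   lib.export_library("vm_model_lib.dylib")
   with open("vm_model_code.ro", "wb") as f:
       f.write(code)
   ```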
   
   > Also, your examples all point to PyTorch, is that because they are good examples and I can do everything they do with ONNX as well, or should I use PyTorch to achieve the functionality I need?
   
   It doesn't matter; after a model is imported via PyTorch or ONNX, the compilation and deployment flow is exactly the same. It's just that I'm more familiar with PyTorch.



[GitHub] [tvm] masahi commented on issue #9939: [Bug] Internal Error With Target == "Metal"

masahi commented on issue #9939:
URL: https://github.com/apache/tvm/issues/9939#issuecomment-1014913956


   I'm not familiar with Metal / ObjC. Maybe nobody has done C++ deployment with Metal before.
   
   I suggest opening a thread at https://discuss.tvm.apache.org/ for more visibility. And please send a PR if you find a solution.



[GitHub] [tvm] masahi edited a comment on issue #9939: [Bug] Internal Error With Target == "Metal"

masahi edited a comment on issue #9939:
URL: https://github.com/apache/tvm/issues/9939#issuecomment-1013754272


   The error is `Check failed: (pval != nullptr) is false: Cannot allocate memory symbolic tensor shape [?, 1280, 9, 16]`. When you do `relay.build`, we use something called the graph codegen / runtime, where all shapes need to be statically known. The `?` indicates your model has a dynamic batch dimension somewhere, so you need to use a compiler and runtime that support dynamic shapes. That's what the VM compiler and runtime are for.
   
   
   
   > @masahi does VM compiler support metal?
   
   Yes, there is no coupling between targets and the VM runtime, all targets work.
   
   > I actually need to consume the model in c++, so I'm assuming all of the above is possible with c++ API as well
   
   In principle yes. Deployment in C++ with the graph runtime is common, but less so for the VM. There is an example in https://github.com/apache/tvm/blob/e7024fb39ea27494fa5618102dae42e7e5551986/src/contrib/torch/pt_call_tvm/tvm_class.cc#L143. On the other hand, I've never seen someone doing VM compilation in C++ (the equivalent of `vm.compile(...)`). By `consume` I hope you meant the former use case.
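   
   For illustration, a minimal sketch of that VM flow, following the linked tutorial; `mod` and `params` from `relay.frontend.from_onnx` are assumptions, and the input shape echoes the error above:
   
   ```
   import numpy as np
   import tvm
   from tvm import relay
   from tvm.runtime.vm import VirtualMachine

   # Assumed inputs: `mod` and `params` come from relay.frontend.from_onnx.
   with tvm.transform.PassContext(opt_level=3):
       vm_exec = relay.vm.compile(mod, target="metal", params=params)

   dev = tvm.metal(0)
   vm = VirtualMachine(vm_exec, dev)
   vm.set_input("main", tvm.nd.array(np.zeros((1, 1280, 9, 16), "float32"), dev))
   tvm_res = vm.run()
   ```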



[GitHub] [tvm] ayeganov commented on issue #9939: [Bug] Internal Error With Target == "Metal"

ayeganov commented on issue #9939:
URL: https://github.com/apache/tvm/issues/9939#issuecomment-1013565196


   Oops, sorry - no. I'll update the description; this snuck in from my tests on the command line.



[GitHub] [tvm] ayeganov commented on issue #9939: [Bug] Internal Error With Target == "Metal"

ayeganov commented on issue #9939:
URL: https://github.com/apache/tvm/issues/9939#issuecomment-1013593029


   When I changed the version of TVM to 0.8.0 and tried compiling my own model, I got the following error:
   
   ```
   [21:39:35] /Users/ayeganov/code/cisco/tvm_exploration/tvm/src/target/target_kind.cc:164: Warning: Unable to detect CUDA version, default to "-mcpu=sm_20" instead
   [21:39:35] /Users/ayeganov/code/cisco/tvm_exploration/tvm/src/target/target_kind.cc:190: Warning: Unable to detect ROCm compute arch, default to "-mcpu=gfx900" instead
   [21:39:35] /Users/ayeganov/code/cisco/tvm_exploration/tvm/src/target/target_kind.cc:204: Warning: Unable to detect ROCm version, assuming >= 3.5
   [21:39:35] /Users/ayeganov/code/cisco/tvm_exploration/tvm/src/target/target_kind.cc:164: Warning: Unable to detect CUDA version, default to "-mcpu=sm_20" instead
   [21:39:35] /Users/ayeganov/code/cisco/tvm_exploration/tvm/src/target/target_kind.cc:190: Warning: Unable to detect ROCm compute arch, default to "-mcpu=gfx900" instead
   [21:39:35] /Users/ayeganov/code/cisco/tvm_exploration/tvm/src/target/target_kind.cc:204: Warning: Unable to detect ROCm version, assuming >= 3.5
   /Users/ayeganov/.pyenv/versions/3.8.6/envs/tvm_for_ladon/lib/python3.8/site-packages/tvm-0.8.0-py3.8-macosx-10.15-x86_64.egg/tvm/relay/frontend/onnx.py:4478: UserWarning: Input encoder_output_1280x9x16 has unknown dimension shapes: ['batch_size', 1280, 9, 16]. Specifying static values may improve performance
     warnings.warn(warning_msg)
   [21:39:35] /Users/ayeganov/code/cisco/tvm_exploration/tvm/src/target/target_kind.cc:164: Warning: Unable to detect CUDA version, default to "-mcpu=sm_20" instead
   [21:39:35] /Users/ayeganov/code/cisco/tvm_exploration/tvm/src/target/target_kind.cc:190: Warning: Unable to detect ROCm compute arch, default to "-mcpu=gfx900" instead
   [21:39:35] /Users/ayeganov/code/cisco/tvm_exploration/tvm/src/target/target_kind.cc:204: Warning: Unable to detect ROCm version, assuming >= 3.5
   Traceback (most recent call last):
     File "/Users/ayeganov/.pyenv/versions/3.8.6/Python.framework/Versions/3.8/lib/python3.8/runpy.py", line 194, in _run_module_as_main
       return _run_code(code, main_globals, None,
     File "/Users/ayeganov/.pyenv/versions/3.8.6/Python.framework/Versions/3.8/lib/python3.8/runpy.py", line 87, in _run_code
       exec(code, run_globals)
     File "/Users/ayeganov/.pyenv/versions/3.8.6/envs/tvm_for_ladon/lib/python3.8/site-packages/tvm-0.8.0-py3.8-macosx-10.15-x86_64.egg/tvm/driver/tvmc/__main__.py", line 24, in <module>
       tvmc.main.main()
     File "/Users/ayeganov/.pyenv/versions/3.8.6/envs/tvm_for_ladon/lib/python3.8/site-packages/tvm-0.8.0-py3.8-macosx-10.15-x86_64.egg/tvm/driver/tvmc/main.py", line 94, in main
       sys.exit(_main(sys.argv[1:]))
     File "/Users/ayeganov/.pyenv/versions/3.8.6/envs/tvm_for_ladon/lib/python3.8/site-packages/tvm-0.8.0-py3.8-macosx-10.15-x86_64.egg/tvm/driver/tvmc/main.py", line 87, in _main
       return args.func(args)
     File "/Users/ayeganov/.pyenv/versions/3.8.6/envs/tvm_for_ladon/lib/python3.8/site-packages/tvm-0.8.0-py3.8-macosx-10.15-x86_64.egg/tvm/driver/tvmc/compiler.py", line 141, in drive_compile
       compile_model(
     File "/Users/ayeganov/.pyenv/versions/3.8.6/envs/tvm_for_ladon/lib/python3.8/site-packages/tvm-0.8.0-py3.8-macosx-10.15-x86_64.egg/tvm/driver/tvmc/compiler.py", line 271, in compile_model
       graph_module = relay.build(mod, target=tvm_target, params=params)
     File "/Users/ayeganov/.pyenv/versions/3.8.6/envs/tvm_for_ladon/lib/python3.8/site-packages/tvm-0.8.0-py3.8-macosx-10.15-x86_64.egg/tvm/relay/build_module.py", line 369, in build
       executor_config, runtime_mod, params = bld_mod.build(
     File "/Users/ayeganov/.pyenv/versions/3.8.6/envs/tvm_for_ladon/lib/python3.8/site-packages/tvm-0.8.0-py3.8-macosx-10.15-x86_64.egg/tvm/relay/build_module.py", line 177, in build
       self._build(mod, target, target_host, executor, mod_name)
     File "/Users/ayeganov/.pyenv/versions/3.8.6/envs/tvm_for_ladon/lib/python3.8/site-packages/tvm-0.8.0-py3.8-macosx-10.15-x86_64.egg/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
       raise get_last_ffi_error()
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     [bt] (8) 9   libtvm.dylib                        0x000000010aa456d6 tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&) + 214
     [bt] (7) 8   libtvm.dylib                        0x000000010aa459df tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>*) const + 303
     [bt] (6) 7   libtvm.dylib                        0x000000010aa5e1d3 tvm::relay::transform::DeviceAwareExprVisitor::VisitExpr_(tvm::relay::FunctionNode const*) + 291
     [bt] (5) 6   libtvm.dylib                        0x000000010ac023db tvm::relay::StorageAllocaBaseVisitor::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*) + 171
     [bt] (4) 5   libtvm.dylib                        0x000000010ac0360c tvm::relay::StorageAllocaBaseVisitor::CreateToken(tvm::RelayExprNode const*, bool) + 156
     [bt] (3) 4   libtvm.dylib                        0x000000010ac0329b tvm::relay::StorageAllocator::CreateTokenOnDevice(tvm::RelayExprNode const*, DLDeviceType, bool) + 1755
     [bt] (2) 3   libtvm.dylib                        0x000000010ac06527 tvm::relay::StorageAllocator::GetMemorySize(tvm::relay::StorageToken*) + 487
     [bt] (1) 2   libtvm.dylib                        0x0000000109975479 tvm::runtime::detail::LogFatal::Entry::Finalize() + 89
     [bt] (0) 1   libtvm.dylib                        0x000000010ae2a068 tvm::runtime::Backtrace() + 24
     File "/Users/ayeganov/code/cisco/tvm_exploration/tvm/src/relay/backend/graph_plan_memory.cc", line 408
   TVMError: 
   ---------------------------------------------------------------
   An error occurred during the execution of TVM.
   For more information, please see: https://tvm.apache.org/docs/errors.html
   ---------------------------------------------------------------
     Check failed: (pval != nullptr) is false: Cannot allocate memory symbolic tensor shape [?, 1280, 9, 16]
   ```



[GitHub] [tvm] electriclilies commented on issue #9939: [Bug] Internal Error With Target == "Metal"

electriclilies commented on issue #9939:
URL: https://github.com/apache/tvm/issues/9939#issuecomment-1013564486


   Is gesture_model.onnx the ResNet-50 referenced in the tutorial?
   



[GitHub] [tvm] ayeganov edited a comment on issue #9939: [Bug] Internal Error With Target == "Metal"

ayeganov edited a comment on issue #9939:
URL: https://github.com/apache/tvm/issues/9939#issuecomment-1013786357


   @masahi Thank you for the helpful answers. I am going to go and try compiling the model with the VM. But I think I still managed to be vague enough with my questions to confuse myself with your answers :). Here is what I want to accomplish and what I currently think I understand:
   
   1. I need to take an existing ONNX model that performs well on CPU but uses too many compute resources to deploy in the wild as is, and convert it to the TVM runtime, utilizing Metal as a hardware accelerator.
   2. Take the converted model and load it into my own library in C++
   
   I don't care about compiling the model to TVM in C++, as long as I can run it in C++. In fact, I'd prefer to do it in Python, because it is easier to script. I did notice this comment in the referenced implementation you linked:
   
   ```
   # Compile with Relay VM
   # ---------------------
   # Note: Currently only CPU target is supported. For x86 target, it is
   # highly recommended to build TVM with Intel MKL and Intel OpenMP to get
   # best performance, due to the existence of large dense operator in
   # torchvision rcnn models.
   ```
   
   Is this outdated? I'll give this a shot in a few minutes, but wanted to bring that to your attention in case there is some discrepancy between examples and actual implemented functionality.
   
   In essence, what I think I understand:
   
   1. Models can be compiled with Python and used via the TVM API in either C++ or Python
   2. Compiling with VM is going to bake it into the model and running it with C++ won't require me to use a VM
   
   Also, your examples all point to PyTorch. Is that because they are good examples and I can do everything they do with ONNX as well, or should I use PyTorch to achieve the functionality I need?



[GitHub] [tvm] ayeganov commented on issue #9939: [Bug] Internal Error With Target == "Metal"

ayeganov commented on issue #9939:
URL: https://github.com/apache/tvm/issues/9939#issuecomment-1013682565


   @masahi does VM compiler support metal? 

