Posted to dev@tvm.apache.org by Yao Wang <no...@github.com> on 2019/10/14 04:57:35 UTC

[dmlc/tvm] [RFC] Dynamic Shape Support - Graph Dispatching (#4118)

## Overview
There are more and more deployment requirements involving dynamic input graphs, such as dynamic-batching CV models and dynamic BERT. While dynamic input graphs are supported in eager mode (PyTorch, TensorFlow Eager, MXNet Gluon) for model development, TVM still only supports models with static input shapes. In this thread I'll discuss possible solutions for dynamic-shape AOT compilation.

Let's start by considering support for a single operator with dynamic shape. TVM already supports tensor expressions with symbolic variables well, so we have no difficulty expressing a dynamic-shape kernel with the existing compute and schedule system. However, a single schedule cannot achieve the desired performance for all possible values of a symbolic axis. For example, a dynamic-batch conv2d on CUDA can require quite different values of block_z and thread_z for different batch sizes. A possible way to solve this problem is to split symbolic axes into several buckets:
![dg1 (1)](https://user-images.githubusercontent.com/15520525/66729153-672d2a00-edfe-11e9-99f9-56b059e41b3b.png)
For each bucket, we select a representative kernel that performs well over the corresponding range of the symbolic axis.
In this thread I won't focus on this topic; @icemelon9, @comaniac, and @sxjscience will dive deeper into it in other threads.

In this thread, we will discuss graph dispatching for dynamic shapes. The bucketing method for kernels works well at runtime for operators that don't require layout transformation, such as dense and batch_matmul (as of today's TVM implementation). However, in computer vision models, conv2d usually requires layout transformation to achieve good performance. Using a kernel dispatch function at runtime raises two issues:
1. A generic layout transform function and a runtime layout tracking system are needed, which introduce a lot of complexity.
2. Graph tuning is not well defined when a kernel dispatch function is used, which leads to performance degradation.

To resolve these issues, instead of a kernel dispatch function, we use a **graph dispatch function**, which splits the input shape of the whole graph into buckets and clones the graph for each bucket:
![dg2 (1)](https://user-images.githubusercontent.com/15520525/66729331-955f3980-edff-11e9-9069-eb5e4fce4bb1.png)
The graph dispatch function is a nested IfThenElse block that selects a copy of the graph depending on the actual input shape. Thanks to the functional nature of Relay, we can easily create global functions in the Relay module to represent the different clones of the graph, and share parameters through function calls. The graph dispatch function has the following advantages (a rough construction sketch follows the list):

1. Modifications are done at the Relay level by inserting a dispatch function that calls the original graph functions. No extra change to the VM runtime is required (though kernel shape functions are needed anyway).
2. Parameter sharing is achieved naturally through function calls. No runtime change is required.
3. Graph tuning can be done for each copy of the graph, and no extra layout tracking system is required.
4. The AutoTVM dispatch context ApplyGraphBest can easily be extended to support this feature.
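
For illustration, here is a minimal Python sketch of how such a nested dispatch block could be assembled from existing Relay expressions. It assumes the cloned graph functions have already been added to the module as hypothetical global functions `f_0 ... f_{n-1}`, one per bucket of the batch axis; the real implementation will live behind the `Dispatch` API described below and may differ.

```python
import tvm
from tvm import relay

def build_dispatch_main(mod, buckets, input_shape, dtype="float32"):
    """Sketch only: wrap per-bucket graph copies f_0..f_{n-1} in nested If."""
    data = relay.var("data", shape=input_shape, dtype=dtype)
    # Runtime batch size: element 0 of the input's shape tensor.
    batch = relay.take(relay.shape_of(data), relay.const(0))

    # Innermost branch handles the last, open-ended bucket [low, +inf).
    body = relay.Call(mod.get_global_var("f_%d" % (len(buckets) - 1)), [data])
    # Wrap the remaining buckets from the inside out:
    # if batch < high_0: f_0 elif batch < high_1: f_1 ... else: f_{n-1}
    for i in reversed(range(len(buckets) - 1)):
        _, high = buckets[i]
        cond = relay.less(batch, relay.const(high, "int32"))
        body = relay.If(cond, relay.Call(mod.get_global_var("f_%d" % i), [data]), body)

    mod["main"] = relay.Function([data], body)
    return mod
```

Each hypothetical `f_i` is a clone of the original graph tuned for its bucket, and calling them as global functions lets all clones share the same parameters.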

## API
We will add a new member function `Dispatch` to the Relay Module:
```c++
void Dispatch(const std::string& func_name, const InputShapeDict& input_shape, const PackedFunc& dispatch_func);
```
This function updates a global function inside the module to become a dispatching block followed by the copied functions. `dispatch_func` decides how the buckets are generated.

### Dispatch Function
A dispatch function maps an input shape dictionary to a nested map: input name -> symbolic axis index -> list of intervals. For example, consider an input shape dictionary that represents a CV model allowing arbitrary image sizes:
```python
{"data": (1, 3, tvm.relay.Any(), tvm.relay.Any())}
```
A logarithmic dispatch function returns a dictionary:
```python
{
  "data":
      {
          2: [(1, 2), (2, 4), (4, 8), (8, 16), (16, 32), (32, 64), (64, 128), (128, 256), (256, None)],
          3: [(1, 2), (2, 4), (4, 8), (8, 16), (16, 32), (32, 64), (64, 128), (128, 256), (256, None)],
      }
}
```
As a result, the final main function will contain 9 * 9 = 81 copies of the original graph. This introduces a tradeoff between overall performance and the number of kernel functions.

We will provide two pre-defined dispatching functions, splitting uniformly and logarithmically. Users can also define their own customized dispatching function, as in the sketch below.
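
As an illustration, a user-defined dispatching function could be as simple as the following plain-Python sketch, which mirrors the dictionary format of the proposed `tvm.relay.vm.log_dispatcher` used in the working example below (the exact signature expected by `Dispatch` is still open):

```python
def my_log_dispatcher(input_shape_dict, max_exp=8):
    """Hypothetical logarithmic dispatcher: power-of-two buckets
    (1, 2), (2, 4), ..., (2**max_exp, None) for every symbolic axis."""
    buckets = [(2 ** i, 2 ** (i + 1)) for i in range(max_exp)]
    buckets.append((2 ** max_exp, None))  # open-ended last bucket
    dispatch_dict = {}
    for name, shape in input_shape_dict.items():
        # Treat any non-integer dimension (e.g. relay.Any()) as symbolic.
        axis_buckets = {idx: list(buckets)
                        for idx, dim in enumerate(shape)
                        if not isinstance(dim, int)}
        if axis_buckets:
            dispatch_dict[name] = axis_buckets
    return dispatch_dict

# my_log_dispatcher({"data": (1, 3, tvm.relay.Any(), tvm.relay.Any())})
# produces the 9-bucket dictionary shown above for axes 2 and 3.
```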

## Pruning buckets through boundaries on symbolic axes
In most practical cases, we don't really need the complete range [1, +inf) for a symbolic axis. A boundary on a tvm.var can greatly reduce the number of buckets and thus the number of kernel functions. In this design we don't consider any boundary pruning yet; we might want to leverage the idea in this topic: https://discuss.tvm.ai/t/discuss-embed-more-bound-information-into-var-or-expr/4079. The sketch below shows how a known bound would cut down the bucket list.
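
To illustrate the potential saving, a hypothetical helper (not part of this proposal) could drop the buckets that fall outside a known bound on the symbolic axis:

```python
def prune_buckets(buckets, low=1, high=32):
    """Hypothetical helper: keep only the buckets that can intersect the
    known range [low, high] of a symbolic axis."""
    pruned = []
    for lo, hi in buckets:
        if hi is not None and hi <= low:
            continue  # bucket lies entirely below the lower bound
        if lo > high:
            continue  # bucket lies entirely above the upper bound
        pruned.append((lo, hi))
    return pruned

# With the 9 logarithmic buckets above and an upper bound of 32, only the
# first 6 buckets (up to (32, 64)) survive, so 6 kernels instead of 9.
```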

## A working example:
```python
import numpy as np
import tvm
from tvm import relay
from gluoncv.model_zoo import get_model  # GluonCV model zoo

input_name = "data"
input_shape = [tvm.relay.Any(), 3, 224, 224]  # symbolic batch dimension
dtype = "float32"
block = get_model('resnet50_v1', pretrained=True)
mod, params = relay.frontend.from_mxnet(block, shape={input_name: input_shape}, dtype=dtype)

# Split the symbolic batch axis into logarithmic buckets and clone the
# graph for each bucket (proposed API).
mod.dispatch("main", {input_name: input_shape}, tvm.relay.vm.log_dispatcher)

vmc = relay.backend.vm.VMCompiler()
with tvm.autotvm.apply_graph_best("resnet50_v1_graph_opt.log"):
    vm = vmc.compile(mod, "llvm")

ctx = tvm.cpu()  # compiling for "llvm", so run on CPU
vm.init(ctx)
vm.load_params(params)

data = np.random.uniform(size=(1, 3, 224, 224)).astype("float32")
out = vm.run(data)

data = np.random.uniform(size=(4, 3, 224, 224)).astype("float32")
out = vm.run(data)
```

## TODO

- [ ] Relay module dispatch function.
- [ ] Shape functions for most common operators in CV models.
- [ ] Graph tuner changes to tune a dispatched graph.

@tqchen @jroesch @icemelon9 @comaniac @sxjscience @yzhliu @wweic @zhiics @yongwww @antinucleon @junrushao1994 


Re: [dmlc/tvm] [RFC] Dynamic Shape Support - Graph Dispatching (#4118)

Posted by Haichen Shen <no...@github.com>.
@soiferj 
1. A shape function is used to compute the output shape(s) of an op at runtime, when they cannot be determined at compilation time. And yes, for now, we have to register a shape function for every op in order to support dynamic shapes.
2. We could do this, but we need to change the attribute of the `full` op to let it take non-constant shapes.
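
To make point 2 concrete, the snippet below contrasts today's constant-shape usage of `full` with the proposed form; the dynamic variant is hypothetical and left commented out, since `full` does not accept a Relay expression for its shape yet.

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(relay.Any(), 128), dtype="float32")

# Today: `full` takes a constant shape tuple known at compile time.
zeros_static = relay.full(relay.const(0.0), shape=(1, 128), dtype="float32")

# Proposed (not yet supported): allow a relay.Expr shape computed at runtime,
# e.g. the shape of another tensor.
# zeros_dynamic = relay.full(relay.const(0.0), relay.shape_of(x), dtype="float32")
```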


Re: [dmlc/tvm] [RFC] Dynamic Shape Support - Graph Dispatching (#4118)

Posted by Yao Wang <no...@github.com>.
@soiferj For the ```full``` op, we can change the input shape argument to be a relay.Expr. We use hybrid script to register shape functions, since most of them are not easy to write as tensor expressions. We only add CPU versions of the shape functions and rely on heterogeneous execution for GPU.
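
For readers following along, here is a rough sketch of what a hybrid-script shape function looks like, loosely modeled on the shape functions in `python/tvm/relay/op/_tensor.py`; the op name and registration details are illustrative, not part of this proposal, and the exact APIs may differ.

```python
# Sketch of a shape function registered through hybrid script. For an
# elementwise op the output shape simply equals the input shape; the
# hybrid function receives the input *shape* as a 1-D int64 tensor.
from tvm.hybrid import script
from tvm.relay.op.op import register_shape_func

@script
def _elemwise_shape_func(input_shape):
    out = output_tensor((input_shape.shape[0],), "int64")
    for i in const_range(input_shape.shape[0]):
        out[i] = input_shape[i]
    return out

def elemwise_shape_func(attrs, inputs, _):
    return [_elemwise_shape_func(inputs[0])]

# False: the shape function depends only on input shapes, not input values.
register_shape_func("nn.relu", False, elemwise_shape_func)
```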


Re: [dmlc/tvm] [RFC] Dynamic Shape Support - Graph Dispatching (#4118)

Posted by Tianqi Chen <no...@github.com>.
Thanks for the proposal. One high-level comment: ideally we want to keep the module API minimal and move transformation-like operations to the transform namespace :)


Re: [dmlc/tvm] [RFC] Dynamic Shape Support - Graph Dispatching (#4118)

Posted by Jon Soifer <no...@github.com>.
Thanks a lot for working on this; it is going to be really impactful, especially for supporting NLP models. I have a couple of questions:

1. Can you please explain the shape function in a little more detail? What exactly is its purpose? Will it have to be registered for every op?
2. Some ops, like `full`, take their shape argument as a constant list. With this change, we could potentially support either a constant list or a Relay expression that is unknown at compile time. How would that work? Would the operator definition have to change?


Re: [dmlc/tvm] [RFC] Dynamic Shape Support - Graph Dispatching (#4118)

Posted by Yao Wang <no...@github.com>.
@tqchen Sure. The dispatch function doesn't need to be coupled with relay::Module.
