Posted to commits@tvm.apache.org by co...@apache.org on 2021/05/04 08:06:43 UTC

[tvm] branch main updated: [BYOC][TensorRT] Fixes for explicit batch mode, Support reduce to scalar, Support split op (#7967)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0e3d850  [BYOC][TensorRT] Fixes for explicit batch mode, Support reduce to scalar, Support split op (#7967)
0e3d850 is described below

commit 0e3d850983d143f043183a68bd69b073c4288222
Author: Trevor Morris <tr...@amazon.com>
AuthorDate: Tue May 4 01:06:32 2021 -0700

    [BYOC][TensorRT] Fixes for explicit batch mode, Support reduce to scalar, Support split op (#7967)
---
 python/tvm/relay/op/contrib/tensorrt.py          | 30 ++++++++++--
 src/relay/backend/contrib/tensorrt/codegen.cc    | 31 ++++++++++++
 src/runtime/contrib/tensorrt/tensorrt_ops.cc     | 62 ++++++++++++++++++++++--
 src/runtime/contrib/tensorrt/tensorrt_runtime.cc |  3 +-
 tests/python/contrib/test_tensorrt.py            | 13 +++++
 5 files changed, 130 insertions(+), 9 deletions(-)

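As context for the diff below: after this change, a Relay graph containing split can be offloaded through the TensorRT BYOC path end to end. A minimal sketch, assuming the partition_for_tensorrt helper and its returned (module, config) pair as they existed around this commit; treat the exact signature as illustrative rather than authoritative:

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

    # A small graph whose split op can now be offloaded to TensorRT.
    x = relay.var("x", shape=(1, 16), dtype="float32")
    out = relay.split(x, indices_or_sections=4, axis=1)
    mod = tvm.IRModule.from_expr(relay.Function([x], out.astuple()))

    # Partition for TensorRT and build with the returned pass config.
    mod, config = partition_for_tensorrt(mod)
    with tvm.transform.PassContext(
        opt_level=3, config={"relay.ext.tensorrt.options": config}
    ):
        lib = relay.build(mod, target="cuda")
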
diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py
index a36b66c..ab36b8b 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -266,7 +266,7 @@ _register_external_op_helper("clip")
 
 def reduce_annotate_fn(attrs, args, op_name):
     """Helper for reduce operations."""
-    if not attrs.axis or len(attrs.axis) == 0:
+    if get_tensorrt_use_implicit_batch_mode() and (not attrs.axis or len(attrs.axis) == 0):
         logger.info("%s: cannot reduce to scalar.", op_name)
         return False
     if attrs.exclude:
@@ -317,10 +317,9 @@ def add_annotate_fn(expr):  # pylint: disable=unused-variable
         for arg in args
     ]
 
-    # RelayVM + TRT doesn't support scalar addition yet.
-    for shape in shapes:
-        if len(shape) < 1:
-            return False
+    # Scalars require explicit batch mode.
+    if get_tensorrt_use_implicit_batch_mode() and any([len(shape) < 1 for shape in shapes]):
+        return False
 
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
@@ -328,6 +327,8 @@ def add_annotate_fn(expr):  # pylint: disable=unused-variable
     if (
         not get_tensorrt_use_implicit_batch_mode()
         and (isinstance(args[0], Constant) or isinstance(args[1], Constant))
+        and len(shapes[0]) > 0
+        and len(shapes[1]) > 0
         and shapes[0][0] == shapes[1][0]
         and shapes[0][0] != 1
         and (len(shapes[0]) > 3 or len(shapes[1]) > 3)
@@ -552,6 +553,19 @@ def concatenate_annotate_fn(expr):  # pylint: disable=unused-variable
     return True
 
 
+@_register_external_dynamic_check_func("split")
+def split_annotate_fn(expr):
+    """Check if split is supported by TensorRT."""
+
+    if any([x.checked_type.dtype != "float32" for x in expr.args]):
+        logger.info("Only float32 inputs are supported for TensorRT.")
+        return False
+    if get_tensorrt_use_implicit_batch_mode() and int(expr.attrs.axis) == 0:
+        logger.info("split: can't modify batch dimension.")
+        return False
+    return True
+
+
 @_register_external_dynamic_check_func("nn.conv2d_transpose")
 def conv2d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if nn.conv2d_transpose is supported by TensorRT."""
@@ -870,6 +884,11 @@ class IsComputeIntensiveGraph(ExprVisitor):
                 "nn.conv3d_transpose",
                 "nn.dense",
                 "nn.batch_matmul",
+                "sum",
+                "prod",
+                "max",
+                "min",
+                "mean",
             ]
         )
         if isinstance(call.op, tvm.tir.op.Op):
@@ -968,6 +987,7 @@ def prune_tensorrt_subgraphs(mod):
     # Create new pruned module
     new_mod = tvm.IRModule(mod.functions, mod.type_definitions)
     new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
+    new_mod = transform.RemoveUnusedFunctions()(new_mod)
     return new_mod
 
 
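With the reduce_annotate_fn change above, a full reduction (no axis, i.e. reduce to scalar) is rejected only under implicit batch mode. A hedged sketch of the difference; use_implicit_batch was a keyword of partition_for_tensorrt in this era of TVM, so take the exact call as illustrative:

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

    x = relay.var("x", shape=(1, 3, 4), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.sum(x)))  # axis=None

    # Implicit batch mode still declines to annotate the scalar reduction;
    # explicit batch mode can now offload it.
    mod, config = partition_for_tensorrt(mod, use_implicit_batch=False)
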
diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc
index e121b60..d83a900 100644
--- a/src/relay/backend/contrib/tensorrt/codegen.cc
+++ b/src/relay/backend/contrib/tensorrt/codegen.cc
@@ -99,6 +99,8 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
       SetPadNodeAttribute(node, cn);
     } else if (name == "strided_slice") {
       SetStridedSliceNodeAttribute(node, cn);
+    } else if (name == "split") {
+      SetSplitNodeAttribute(node, cn);
     } else {
       SetCallNodeAttribute(node, cn);
     }
@@ -172,6 +174,35 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
     node->SetAttr("strides", strides_attr);
   }
 
+  void SetSplitNodeAttribute(std::shared_ptr<JSONGraphNode> node, const CallNode* cn) {
+    const auto* split_attr = cn->attrs.as<SplitAttrs>();
+    ICHECK(split_attr);
+
+    std::vector<std::string> indices_or_sections;
+    std::vector<std::string> mode;
+    std::vector<std::string> axis = {std::to_string(split_attr->axis)};
+    if (const IntImmNode* sections = split_attr->indices_or_sections.as<IntImmNode>()) {
+      mode.emplace_back("sections");
+      indices_or_sections.emplace_back(std::to_string(sections->value));
+    } else {
+      mode.emplace_back("indices");
+      auto indices = Downcast<tvm::Array<Integer>>(split_attr->indices_or_sections);
+      for (const auto& i : indices) {
+        indices_or_sections.emplace_back(std::to_string(i->value));
+      }
+    }
+
+    std::vector<dmlc::any> indices_or_sections_attr;
+    std::vector<dmlc::any> mode_attr;
+    std::vector<dmlc::any> axis_attr;
+    indices_or_sections_attr.emplace_back(indices_or_sections);
+    mode_attr.emplace_back(mode);
+    axis_attr.emplace_back(axis);
+    node->SetAttr("indices_or_sections", indices_or_sections_attr);
+    node->SetAttr("mode", mode_attr);
+    node->SetAttr("axis", axis_attr);
+  }
+
   void SaveGlobalAttributes(std::shared_ptr<JSONGraphNode> node) {
     auto ctx = transform::PassContext::Current();
     auto cfg = ctx->GetConfig<TensorRTCompilerConfig>("relay.ext.tensorrt.options");
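
SetSplitNodeAttribute above flattens SplitAttrs into per-node string attributes in the JSON graph. A sketch of what the serialized node would carry for relay.split(x, indices_or_sections=[2, 3], axis=1), with the layout inferred from the code above (hypothetical values):

    # Each attribute mirrors a std::vector<std::string> payload from the
    # serializer; the runtime converter parses the strings back to ints.
    split_node_attrs = {
        "mode": ["indices"],               # "sections" when an IntImm is given
        "indices_or_sections": ["2", "3"],
        "axis": ["1"],
    }
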
diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
index 04b1e83..9b108fa 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -723,6 +723,53 @@ class ConcatOpConverter : public TensorRTOpConverter {
   }
 };
 
+class SplitOpConverter : public TensorRTOpConverter {
+ public:
+  SplitOpConverter() : TensorRTOpConverter({kTensor}) {}
+
+  void Convert(TensorRTOpConverterParams* params) const {
+    auto input = params->inputs.at(0).tensor;
+    auto input_dims = TrtDimsToVector(input->getDimensions());
+    const int original_axis = std::stoi(params->node.GetAttr<std::vector<std::string>>("axis")[0]);
+    const int axis = ConvertAxis(params, original_axis, input_dims.size());
+    auto indices_or_sections =
+        params->node.GetAttr<std::vector<std::string>>("indices_or_sections");
+    auto mode = params->node.GetAttr<std::vector<std::string>>("mode")[0];
+
+    std::vector<int> split_starts;
+    std::vector<int> split_sizes;
+    if (mode == "sections") {
+      int sections = std::stoi(indices_or_sections[0]);
+      int size = input_dims[axis] / sections;
+      for (int i = 0; i < sections; i++) {
+        split_starts.push_back(i * size);
+        split_sizes.push_back(size);
+      }
+    } else {
+      int last_index = 0;
+      for (size_t i = 0; i < indices_or_sections.size(); ++i) {
+        int index = std::stoi(indices_or_sections[i]);
+        split_starts.push_back(last_index);
+        split_sizes.push_back(index - last_index);
+        last_index = index;
+      }
+      split_starts.push_back(last_index);
+      split_sizes.push_back(input_dims[axis] - last_index);
+    }
+
+    std::vector<int> start(input_dims.size(), 0);
+    std::vector<int> size(input_dims.begin(), input_dims.end());
+    std::vector<int> strides(input_dims.size(), 1);
+    for (size_t i = 0; i < split_sizes.size(); ++i) {
+      start[axis] = split_starts[i];
+      size[axis] = split_sizes[i];
+      auto slice_layer = params->network->addSlice(*input, VectorToTrtDims(start),
+                                                   VectorToTrtDims(size), VectorToTrtDims(strides));
+      params->outputs.push_back(slice_layer->getOutput(0));
+    }
+  }
+};
+
 class BiasAddOpConverter : public TensorRTOpConverter {
  public:
   BiasAddOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
@@ -970,9 +1017,17 @@ class ReduceOpConverter : public TensorRTOpConverter {
     // TODO(trevmorr): Support reduce to scalar.
     ICHECK_GT(str_axis.size(), 0);
     uint32_t reduce_axes = 0;
-    for (size_t i = 0; i < str_axis.size(); ++i) {
-      const int axis = ConvertAxis(params, std::stoi(str_axis[i]), input->getDimensions().nbDims);
-      reduce_axes |= 1 << axis;
+
+    if (str_axis.size() == 1 && str_axis[0].length() == 0) {
+      // Reduce to scalar
+      for (int i = 0; i < input->getDimensions().nbDims; ++i) {
+        reduce_axes |= 1 << i;
+      }
+    } else {
+      for (size_t i = 0; i < str_axis.size(); ++i) {
+        const int axis = ConvertAxis(params, std::stoi(str_axis[i]), input->getDimensions().nbDims);
+        reduce_axes |= 1 << axis;
+      }
     }
     auto reduce_layer = params->network->addReduce(*input, it->second, reduce_axes, keepdims);
     params->outputs.push_back(reduce_layer->getOutput(0));
@@ -1072,6 +1127,7 @@ GetOpConverters() {
   map->emplace("expand_dims", std::make_shared<ExpandDimsOpConverter>());
   map->emplace("squeeze", std::make_shared<SqueezeOpConverter>());
   map->emplace("concatenate", std::make_shared<ConcatOpConverter>());
+  map->emplace("split", std::make_shared<SplitOpConverter>());
   map->emplace("nn.conv2d_transpose", std::make_shared<Conv2DTransposeOpConverter>());
   map->emplace("transpose", std::make_shared<TransposeOpConverter>());
   map->emplace("layout_transform", std::make_shared<LayoutTransformOpConverter>());
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index 21031c6..7efa5bf 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -185,7 +185,8 @@ class TensorRTRuntime : public JSONRuntimeBase {
    * do nothing.
    */
   void BuildEngine() {
-    batch_size_ = data_entry_[input_var_eid_[0]]->shape[0];
+    batch_size_ =
+        data_entry_[input_var_eid_[0]]->ndim == 0 ? 1 : data_entry_[input_var_eid_[0]]->shape[0];
     if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size_))) return;
     DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_
                << " with batch size " << batch_size_;
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index 2bef7be..c6b714c 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -605,6 +605,19 @@ def test_concatenate():
     run_and_verify_func(get_graph([(1, 2, 6, 6), (1, 3, 6, 6)], axis=1))
 
 
+def test_split():
+    def get_graph(x_shape, indices_or_sections, axis):
+        x = relay.var("x", shape=(x_shape), dtype="float32")
+        out = relay.split(x, indices_or_sections=indices_or_sections, axis=axis)
+        f = relay.Function([x], out.astuple())
+        return f, {"x": x_shape}, []
+
+    run_and_verify_func(get_graph((1, 16), indices_or_sections=2, axis=1))
+    run_and_verify_func(get_graph((1, 16), indices_or_sections=4, axis=1))
+    run_and_verify_func(get_graph((1, 16), indices_or_sections=[8], axis=1))
+    run_and_verify_func(get_graph((1, 16), indices_or_sections=[2, 3, 6, 10, 14], axis=1))
+
+
 def test_conv2d_transpose():
     def get_graph(
         x_shape=(1, 32, 8, 8),