Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/16 00:22:30 UTC

[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6062: [Relay][Pass] Support combine multiple dense op just into dense

comaniac commented on a change in pull request #6062:
URL: https://github.com/apache/incubator-tvm/pull/6062#discussion_r455434066



##########
File path: include/tvm/relay/transform.h
##########
@@ -223,10 +223,11 @@ TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
  * `min_num_branch`.
  *
  * \param min_num_branches The minimun number of branches.
+ * \param to_batch Combine matmuls to batch matmul.
  *
  * \return The pass.
  */
-TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3);
+TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch = true);

Review comment:
       The name `to_batch` could be improved, as it is unclear what the result will be when `to_batch=false`. At a minimum, we should improve the docstring to say that the output is a single dense op when `to_batch=false`.
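       For example, the docstring could be reworded along these lines (just a sketch of the wording, keeping the current parameter name):

```cpp
/*!
 * \param min_num_branches The minimum number of branches.
 * \param to_batch If true, combine parallel dense ops into a single
 *                 batch_matmul; otherwise combine them into a single
 *                 dense op whose output units are the concatenation of
 *                 the branches' output units.
 *
 * \return The pass.
 */
TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch = true);
```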

##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -948,7 +948,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
   pass_seqs.push_back(transform::InlinePrimitives());
 
   pass_seqs.push_back(transform::CombineParallelConv2D(3));
-  pass_seqs.push_back(transform::CombineParallelDense(3));
+  pass_seqs.push_back(transform::CombineParallelDense(3, true));

Review comment:
       ditto.

##########
File path: src/relay/transforms/combine_parallel_dense.cc
##########
@@ -68,17 +71,168 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner {
   }
 };
 
+/*
+ * Class that finds and combines parallel dense ops into one dense op
+ * whose number of output units equals the sum of each sub-op's.
+ */
+class ParallelDenseFlatCombiner : public ParallelOpCombiner {
+ public:
+  explicit ParallelDenseFlatCombiner(uint64_t min_num_branches)
+      : ParallelOpCombiner("nn.dense", min_num_branches) {}
+
+ protected:
+  bool IsSupportedOp(const CallNode* n) { return true; }
+
+  bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
+    StructuralEqual eq;
+    const auto* attrs_a = a->attrs.as<DenseAttrs>();
+    const auto* attrs_b = b->attrs.as<DenseAttrs>();
+    const auto* weight_a = a->args[1]->type_as<TensorTypeNode>();
+    const auto* weight_b = b->args[1]->type_as<TensorTypeNode>();
+    CHECK(attrs_a != nullptr && attrs_b != nullptr && weight_a != nullptr && weight_b != nullptr);
+    // output dims (weight->shape[0]) can be different
+    return eq(attrs_a->out_dtype, attrs_b->out_dtype) && eq(weight_a->shape[1], weight_b->shape[1]);
+  }
+
+  Call MakeCombinedOp(const Group& branches) {
+    const Op& dense_op = Op::Get("nn.dense");
+    Expr input = branches[0][0]->args[0];
+    Expr new_weight;
+    IndexExpr new_output_dims;
+    // concat all weights into one
+    std::tie(new_weight, new_output_dims) = TransformWeight(branches);
+    const auto* origin_attrs = branches[0][0]->attrs.as<DenseAttrs>();
+    CHECK(origin_attrs);
+    const auto dense_attrs = make_object<DenseAttrs>();
+    dense_attrs->units = new_output_dims;
+    dense_attrs->out_dtype = origin_attrs->out_dtype;
+    return Call(dense_op, {input, new_weight}, Attrs{dense_attrs}, {});
+  }
+
+  bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
+    StructuralEqual eq;
+    auto ta = a->args[index]->type_as<TensorTypeNode>();
+    auto tb = b->args[index]->type_as<TensorTypeNode>();
+    auto toutput_a = a->type_as<TensorTypeNode>();
+    auto toutput_b = b->type_as<TensorTypeNode>();
+    CHECK(ta != nullptr && tb != nullptr && toutput_a != nullptr && toutput_b != nullptr);
+
+    if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) {
+      return false;
+    }
+    if (toutput_a->shape.size() < ta->shape.size() || toutput_b->shape.size() < tb->shape.size()) {
+      return false;  // not broadcast/elemwise
+    }
+    if (ta->shape.size() > 0) {
+      for (size_t i = 0; i < ta->shape.size() - 1; i++) {
+        // shape dims must match except last dim
+        if (!eq(ta->shape[i], tb->shape[i])) return false;
+      }
+    }
+    return true;
+  }
+
+  Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
+                                        size_t parent_index) {
+    Array<Expr> new_args;
+    const CallNode* call = branches[0][depth];
+    for (size_t i = 0; i < call->args.size(); i++) {
+      if (i == parent_index) {
+        new_args.push_back(data);
+        continue;
+      }
+      size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
+      size_t concat_axis = arg_ndim == 0 ? 0 : arg_ndim - 1;
+      Array<Expr> tuple;
+      for (const auto& branch : branches) {
+        auto parent = branch[depth]->args[parent_index];
+        auto& parent_shape = parent->type_as<TensorTypeNode>()->shape;
+        auto out_dim = tir::as_const_int(parent_shape[parent_shape.size() - 1]);
+        CHECK(out_dim != nullptr);
+
+        auto arg = branch[depth]->args[i];
+        auto& arg_shape = arg->type_as<TensorTypeNode>()->shape;
+        bool repeat_last_dim = false;
+        if (arg_ndim == 0) {
+          repeat_last_dim = true;
+          arg = MakeExpandDims(arg, -1, 1);
+        } else {
+          auto arg_last_dim = tir::as_const_int(arg_shape[arg_shape.size() - 1]);
+          CHECK(arg_last_dim != nullptr);
+          if (*out_dim > 1 && *arg_last_dim == 1) {
+            repeat_last_dim = true;
+          }
+        }
+        if (repeat_last_dim) {
+          // ensure broadcast is valid after concat args
+          arg = MakeRepeat(arg, *out_dim, concat_axis);
+        }
+        tuple.push_back(arg);
+      }
+      auto concat = MakeConcatenate(Tuple(tuple), concat_axis);
+      new_args.push_back(std::move(concat));
+    }
+    return Call(call->op, new_args, call->attrs, {});
+  }
+
+  void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
+                         ExprSubstMap* subst_map) {
+    int index = 0;
+    for (const auto& branch : branches) {
+      const CallNode* call = branch[depth];
+      auto& out_shape = call->type_as<TensorTypeNode>()->shape;
+      auto out_dims = tir::as_const_int(out_shape[out_shape.size() - 1]);
+      CHECK(out_dims != nullptr);
+      std::vector<int64_t> begin;
+      std::vector<int64_t> end;
+      std::vector<int64_t> strides;
+      for (size_t k = 0; k < out_shape.size() - 1; ++k) {
+        begin.push_back(0);
+        end.push_back(-1);
+        strides.push_back(1);
+      }
+      begin.push_back(index);
+      end.push_back(*out_dims);
+      strides.push_back(1);
+      index += *out_dims;
+      std::vector<int64_t> ndarray_shape = {static_cast<int64_t>(begin.size())};
+      Constant begin_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, begin);
+      Constant end_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, end);
+      Constant strides_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, strides);
+      auto slice = MakeStridedSlice(data, begin_const, end_const, strides_const, "size");
+      subst_map->insert({GetRef<Expr>(branch[depth]), slice});
+    }
+  }
+
+ private:
+  std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
+    int64_t out_dims = 0;
+    Array<Expr> weights;
+    for (const auto& branch : branches) {
+      auto weight = branch[0]->args[1];
+      weights.push_back(weight);
+      out_dims += *tir::as_const_int(weight->type_as<TensorTypeNode>()->shape[0]);
+    }
+    return std::make_tuple(MakeConcatenate(Tuple(weights), 0),
+                           tir::make_const(DataType::Int(32), out_dims));
+  }
+};
+
 /*! \brief Combine parallel dense if number of branches >= min_num_branches */
-Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) {
-  return ParallelDenseCombiner(min_num_branches).Combine(expr);
+Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches, bool to_batch) {
+  if (to_batch) {
+    return ParallelDenseBatchCombiner(min_num_branches).Combine(expr);
+  } else {
+    return ParallelDenseFlatCombiner(min_num_branches).Combine(expr);

Review comment:
       The names of the two combiners are confusing. Would it be clearer to just have `ParallelDenseToBatchCombiner` and `ParallelDenseToDenseCombiner`?
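       For instance, the dispatch could read as follows with the suggested names (a sketch only; the combiner bodies stay as in this PR):

```cpp
Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches, bool to_batch) {
  if (to_batch) {
    // Combine the parallel dense branches into one batch_matmul.
    return ParallelDenseToBatchCombiner(min_num_branches).Combine(expr);
  } else {
    // Combine them into one dense op with concatenated output units.
    return ParallelDenseToDenseCombiner(min_num_branches).Combine(expr);
  }
}
```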

##########
File path: src/relay/backend/build_module.cc
##########
@@ -277,7 +277,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     });
     pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
     pass_seqs.push_back(transform::CombineParallelConv2D(3));
-    pass_seqs.push_back(transform::CombineParallelDense(3));
+    pass_seqs.push_back(transform::CombineParallelDense(3, true));

Review comment:
       Since the default is `true`, you don't need to make this change, right?
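       In other words, because `to_batch` defaults to `true`, the original call should already select the batch path, e.g.:

```cpp
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));  // to_batch defaults to true
```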




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org