You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/26 06:58:16 UTC

[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #4335: [Relay][WIP] ConvertLayout pass.

anijain2305 commented on a change in pull request #4335: [Relay][WIP] ConvertLayout pass.
URL: https://github.com/apache/incubator-tvm/pull/4335#discussion_r350568244
 
 

 ##########
 File path: src/relay/pass/alter_op_layout.cc
 ##########
 @@ -36,333 +36,63 @@
 #include <utility>
 #include <unordered_map>
 
-#include "alter_op_layout.h"
+#include "transform_layout.h"
 #include "pattern_util.h"
 
 namespace tvm {
 namespace relay {
 
 namespace alter_op_layout {
 
-// Make a transform CallNode
-/* Performs 2 operations
- * 1) If src_layout ndim is smaller then dst_layout, expand_dim is inserted to match the dim size.
- *    For example, src_layout = C, dst_layout = NCHW16c. The src is expanded to NHWC.
- * 2) Call layout transform with new src layout.
- */
-Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) {
-  if (src_layout.Equals(dst_layout)) {
-    return raw;
-  }
-
-  // 1) Check if the shape lengths are different. If yes, expand dims.
-  Expr input_expr = raw;
-  Layout new_src_layout = src_layout;
-  if (src_layout.ndim_primal() < dst_layout.ndim_primal()) {
-    int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal();
-    new_src_layout = src_layout.ExpandPrimal(dst_layout);
-    input_expr = MakeExpandDims(input_expr, 0, num_new_axis);
-    if (new_src_layout.Equals(dst_layout)) {
-      return input_expr;
-    }
-  }
-
-  // 2) Insert layout transform on the transformed src.
-  CHECK(new_src_layout.defined() && dst_layout.defined())
-      << "Cannot insert layout transform because there are undefined layouts";
-  CHECK(BijectiveLayoutNode::make(new_src_layout, dst_layout).defined())
-      << "Cannot insert layout transform because there are inconvertible layouts: "
-      << new_src_layout << " v.s. " << dst_layout;
-  return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name());
-}
-
-// Memorize layout transform so we can reuse internal transformed nodes
-class TransformMemorizerNode : public Node {
- public:
-  // map from (Expr, src_layout, dst_layout) to transformed Expr
-  using TransformKey = std::tuple<const Node*, std::string, std::string>;
-struct key_hash : public std::function<std::size_t(TransformKey)> {
-    std::size_t operator()(const TransformKey& k) const {
-      return dmlc::HashCombine<std::string>(dmlc::HashCombine<std::string>(
-              std::hash<const Node*>()(std::get<0>(k)), std::get<1>(k)), (std::get<2>(k)));
-    }
-  };
-
-  std::unordered_map<TransformKey, Expr, key_hash> memo;
-  static constexpr const char *_type_key = "relay.alter_op_layout.TransformMemorizerNode";
-  TVM_DECLARE_NODE_TYPE_INFO(TransformMemorizerNode, Node);
-};
-
-class TransformMemorizer : public NodeRef {
- public:
-  TransformMemorizer() {}
-  explicit TransformMemorizer(ObjectPtr<Object> n) : NodeRef(n) {}
-
-  TransformMemorizerNode* operator->() {
-    return static_cast<TransformMemorizerNode*>(get_mutable());
-  }
-
-  // Transform layout with memorizer
-  Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) {
-    if (src_layout.Equals(dst_layout)) { return raw; }
-
-    std::tuple<const Node*, std::string, std::string> key =
-        std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name());
-    auto& memo = operator->()->memo;
-
-    auto iter = memo.find(key);
-    if (iter != memo.end()) {
-      return iter->second;
-    } else {
-      Expr transform = TransformLayout(raw, src_layout, dst_layout);
-      memo[key] = transform;
-      return transform;
-    }
-  }
-
-  using ContainerType = TransformMemorizerNode;
-};
-
-
-// TempExprNode during layout transform
-// Instance of this expr will be Realized to normal expr ultimately
-class LayoutAlternatedExprNode : public TempExprNode {
+class AlterNode : public TransformMemorizerNode {
  public:
-  Expr value;
-  Layout old_layout;
-  Layout new_layout;
-  TransformMemorizer memorizer;
-
-  Expr Realize() const final {
-    // NOTE: use a copy to discard the "const" qualifier
-    TransformMemorizer tmp_memorizer = memorizer;
-    // fallback to old layout
-    return tmp_memorizer.Transform(value, new_layout, old_layout);
-  }
-
-  void VisitAttrs(AttrVisitor *v) {
-    v->Visit("value", &value);
-    v->Visit("old_layout", &old_layout);
-    v->Visit("new_layout", &new_layout);
-  }
-
-  static constexpr const char *_type_key = "relay.alter_op_layout.LayoutAlternatedExprNode";
-  TVM_DECLARE_NODE_TYPE_INFO(LayoutAlternatedExprNode, TempExprNode);
-};
-
-RELAY_DEFINE_NODE_REF(LayoutAlternatedExpr, LayoutAlternatedExprNode, TempExpr);
-
-// Call registered FInferCorrectLayout of an op.
-// Parameters are the same as the parameters for FInferCorrectLayout
-// Returns inferred_input_layout, inferred_output_layout, success
-std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
-    const Call& call,
-    const Array<Layout>& new_in_layouts,
-    const Array<Layout>& old_in_layouts,
-    const Array<Array<IndexExpr> > &old_in_shapes) {
-  static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout");
-
-  Op op = Downcast<Op>(call->op);
-  if (finfer_layout.count(op)) {
-    Array<Array<Layout> > inferred_layouts;
-    inferred_layouts = finfer_layout[op](call->attrs, new_in_layouts,
-                                         old_in_layouts, old_in_shapes);
-    CHECK_EQ(inferred_layouts.size(), 2)
-      << "FInferCorrectLayout should return an array with size of 2";
-    for (auto x : inferred_layouts) {
-      for (auto y : x) {
-        if (!y.defined()) {  // inference fails
-          return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false);
-        }
+  /*!
+   * \brief Defines the call transformation for AlterOpLayout pass. The new layouts are defined by
+   * used for different targets using a packed func.
+   * \param ref_call The original call.
+   * \param new_args The traversed/recursed args to the call.
+   * \return The new Call after calling the packed func.
+   */
+  Call GetCallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) const {
 
 Review comment:
   This function name is shared between the AlterOpLayout and ConvertLayout passes, but I will try to come up with a better name.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services