You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/06/23 06:10:32 UTC

[tvm] branch main updated: [Relay][Pass] Meta-Schedule-Layout-Rewrite (#11845)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b47725627e [Relay][Pass] Meta-Schedule-Layout-Rewrite (#11845)
b47725627e is described below

commit b47725627e264aba5679534db4589f540d4bcd7c
Author: Hongyi Jin <32...@qq.com>
AuthorDate: Thu Jun 23 14:10:25 2022 +0800

    [Relay][Pass] Meta-Schedule-Layout-Rewrite (#11845)
---
 include/tvm/relay/transform.h                      |   6 +
 src/relay/backend/build_module.cc                  |  14 ++
 src/relay/backend/vm/compiler.cc                   |  14 ++
 .../transforms/meta_schedule_layout_rewrite.cc     | 175 +++++++++++++++++++++
 .../transforms/meta_schedule_layout_rewrite.h      |  38 +++++
 5 files changed, 247 insertions(+)

diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 6e3bddf9ad..b592265c74 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -371,6 +371,12 @@ TVM_DLL Pass AlterOpLayout();
  */
 TVM_DLL Pass AutoSchedulerLayoutRewrite();
 
+/*!
+ * \brief Do layout rewrite according to the tile structure created by meta-schedule.
+ * \return The pass
+ */
+TVM_DLL Pass MetaScheduleLayoutRewrite();
+
 /*!
  * \brief Given a dest layout, this pass transforms the expr such that most of the ops input data
  * layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 578a62ca02..628dee0844 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -379,6 +379,20 @@ class RelayBuildModule : public runtime::ModuleNode {
         relay_module = transform::FuseOps()(relay_module);
       }
     }
+    if (backend::IsMetaScheduleEnabled() && config_->optional_homogeneous_target.defined()) {
+      Pass major_pass = transform::MetaScheduleLayoutRewrite();
+      bool enable_layout_rewrite_targets =
+          config_->optional_homogeneous_target->kind->device_type == kDLCPU ||
+          config_->optional_homogeneous_target->GetAttr<String>("device", "") == "mali";
+      if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) {
+        With<Target> tctx(config_->optional_homogeneous_target);
+        relay_module = major_pass(relay_module);
+        // Defuse ops to fold constants, then fuse them again
+        relay_module = transform::DefuseOps()(relay_module);
+        relay_module = transform::FoldConstant()(relay_module);
+        relay_module = transform::FuseOps()(relay_module);
+      }
+    }
 
     relay_module = transform::InferType()(relay_module);
 
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 8820a403bf..7371fd1f80 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1078,6 +1078,20 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
       pass_seqs.push_back(transform::FuseOps());
     }
   }
+  if (backend::IsMetaScheduleEnabled() && config_->optional_homogeneous_target.defined()) {
+    Pass major_pass = transform::MetaScheduleLayoutRewrite();
+    bool enable_layout_rewrite_targets =
+        config_->optional_homogeneous_target->kind->device_type == kDLCPU ||
+        config_->optional_homogeneous_target->GetAttr<String>("device", "") == "mali";
+    if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) {
+      With<Target> tctx(config_->optional_homogeneous_target);
+      pass_seqs.push_back(major_pass);
+      // Defuse ops to fold constants, then fuse them again
+      pass_seqs.push_back(transform::DefuseOps());
+      pass_seqs.push_back(transform::FoldConstant());
+      pass_seqs.push_back(transform::FuseOps());
+    }
+  }
 
   pass_seqs.push_back(transform::ToANormalForm());
   pass_seqs.push_back(transform::InferType());
diff --git a/src/relay/transforms/meta_schedule_layout_rewrite.cc b/src/relay/transforms/meta_schedule_layout_rewrite.cc
new file mode 100644
index 0000000000..b817802f17
--- /dev/null
+++ b/src/relay/transforms/meta_schedule_layout_rewrite.cc
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./meta_schedule_layout_rewrite.h"
+
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+#include <algorithm>
+#include <deque>
+#include <mutex>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "../backend/te_compiler.h"
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Process-wide FIFO of tir::IndexMap objects.
+ *
+ * Entries are appended through MetaScheduleLayoutRewriter::LayoutQueuePush and read back by
+ * MetaScheduleLayoutRewriter, which is granted friend access to the private members.
+ */
+class LayoutIndexQueue {
+ public:
+  /*! \brief Return the address of the single function-local static instance. */
+  static LayoutIndexQueue* Global() {
+    static LayoutIndexQueue singleton;
+    return &singleton;
+  }
+
+  /*! \brief Remove every queued index map while holding the mutex. */
+  void Clear() {
+    std::lock_guard<std::mutex> guard(mutex_);
+    queue_.clear();
+  }
+
+ private:
+  friend class MetaScheduleLayoutRewriter;
+  /*! \brief Mutex taken by Clear() and LayoutQueuePush() around accesses to queue_. */
+  std::mutex mutex_;
+  /*! \brief Pending index maps, oldest first. */
+  std::deque<tir::IndexMap> queue_;
+};
+
+/*!
+ * \brief Append an index map to the back of the global layout queue.
+ * \param index_map The index map to enqueue.
+ */
+void MetaScheduleLayoutRewriter::LayoutQueuePush(const tir::IndexMap& index_map) {
+  LayoutIndexQueue* global_queue = LayoutIndexQueue::Global();
+  // The guard lives until the function returns, so the append happens under the mutex.
+  std::lock_guard<std::mutex> guard(global_queue->mutex_);
+  global_queue->queue_.push_back(index_map);
+}
+
+/*!
+ * \brief Check whether an operator is one of those handled by the layout rewrite.
+ * \param op The operator node to test.
+ * \return true if op->name equals one of the names listed in the table below.
+ */
+bool IsSupportedOp(const OpNode* op) {
+  // Read-only lookup table: `const` keeps the shared function-local static from being mutated.
+  static const std::vector<std::string> target_ops{
+      "nn.conv2d",  //
+      "nn.contrib_conv2d_winograd_without_weight_transform",
+      "nn.conv3d",
+      "nn.matmul",
+      "nn.dense",
+      "nn.batch_matmul",
+  };
+  return std::find(target_ops.begin(), target_ops.end(), op->name) != target_ops.end();
+}
+
+// When `Attr` holds an `AttrType` node: copy that node, store `OriginalShape` in the copy's
+// `meta_schedule_original_shape` field, and assign the wrapped copy to `Result`.
+// When the cast fails, `Result` is left untouched.
+// NOTE(review): this expands to a bare `if` statement, so do not place it directly before
+// an `else`.
+#define TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(Attr, AttrType, OriginalShape, Result) \
+  if (const AttrType* ptr = Attr.as<AttrType>()) {                                  \
+    ObjectPtr<AttrType> n = make_object<AttrType>(*ptr);                            \
+    n->meta_schedule_original_shape = OriginalShape;                                \
+    Result = Attrs(n);                                                              \
+  }
+
+/*!
+ * \brief Mutator that rewrites the supported ops found inside one function.
+ *
+ * Each supported call that is visited consumes the index map at the front of the queue: the
+ * call's second argument is wrapped by MakeMetaScheduleLayoutTransform, and the shape taken
+ * from that argument's type annotation is recorded in a copy of the call's attributes.
+ */
+class MetaScheduleFuncMutator : public ExprMutator {
+ public:
+  /*!
+   * \brief Constructor.
+   * \param layout_queue Index maps to consume, one per rewritten call, front first.
+   */
+  explicit MetaScheduleFuncMutator(std::deque<tir::IndexMap>&& layout_queue)
+      : layout_queue_(std::move(layout_queue)) {}
+
+  /*!
+   * \brief Rewrite a call to a supported op; every other call is returned as visited.
+   * \param call The call node being visited.
+   * \return The visited call, with its second argument layout-transformed when applicable.
+   */
+  Expr VisitExpr_(const CallNode* call) final {
+    Expr expr = ExprMutator::VisitExpr_(call);
+    if (layout_queue_.empty()) {
+      return expr;
+    }
+    // Work on the node produced by the base visitor, under a name that does not shadow `call`.
+    const auto* new_call = expr.as<CallNode>();
+    if (new_call == nullptr) {
+      return expr;
+    }
+    const auto* op = new_call->op.as<OpNode>();
+    if (op == nullptr || !IsSupportedOp(op)) {
+      return expr;
+    }
+    ICHECK_EQ(new_call->args.size(), 2);
+    tir::IndexMap index_map = layout_queue_.front();
+    layout_queue_.pop_front();
+    // The second argument is required to be a Var; the shape in its type annotation is what
+    // gets stored as the original shape.
+    Var var = Downcast<Var>(new_call->args[1]);
+    Array<PrimExpr> shape = Downcast<TensorType>(var->type_annotation)->shape;
+    // Only an expansion whose attribute type matches the call's attrs assigns `attrs`.
+    Attrs attrs{nullptr};
+    TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(new_call->attrs, Conv2DAttrs, shape, attrs);
+    TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(new_call->attrs, Conv2DWinogradAttrs, shape, attrs);
+    TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(new_call->attrs, Conv3DAttrs, shape, attrs);
+    TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(new_call->attrs, MatmulAttrs, shape, attrs);
+    TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(new_call->attrs, DenseAttrs, shape, attrs);
+    TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE(new_call->attrs, BatchMatmulAttrs, shape, attrs);
+    ICHECK(attrs.defined()) << "TypeError: Unknown attribute: " << new_call->attrs;
+    return Call(new_call->op,
+                {new_call->args[0], MakeMetaScheduleLayoutTransform(new_call->args[1], index_map)},
+                attrs);
+  }
+
+ private:
+  /*! \brief Index maps not yet consumed by a rewritten call. */
+  std::deque<tir::IndexMap> layout_queue_;
+};
+
+#undef TVM_RELAY_LAYOUT_WITH_ORIGINAL_SHAPE
+
+/*!
+ * \brief Visit a call and, when its callee is a Relay Function, apply the queued layout rewrites.
+ *
+ * The global queue is emptied, tec::PrimFuncFor is run on the callee under Target::Current(),
+ * and any index maps found in the queue afterwards are handed to MetaScheduleFuncMutator.
+ * Presumably those entries arrive through LayoutQueuePush during the PrimFuncFor call; that
+ * path is not visible in this file.
+ *
+ * \param call The call node being visited.
+ * \return The rewritten expression, or the visited call when nothing was queued.
+ */
+Expr MetaScheduleLayoutRewriter::VisitExpr_(const CallNode* call) {
+  Expr expr = ExprMutator::VisitExpr_(call);
+  const auto* new_call = expr.as<CallNode>();
+  if (new_call == nullptr) {
+    return expr;
+  }
+  const auto* func = new_call->op.as<FunctionNode>();
+  if (func == nullptr) {
+    return expr;
+  }
+  LayoutIndexQueue* self = LayoutIndexQueue::Global();
+  // Clear() takes the mutex, matching how LayoutQueuePush accesses the queue.
+  self->Clear();
+  // The mutex must NOT be held here: LayoutQueuePush locks the same non-recursive mutex.
+  tec::PrimFuncFor(GetRef<Function>(func), Target::Current(),
+                   [](std::string name) { return name; });
+  // Take ownership of everything that was queued, leaving the global queue empty.
+  std::deque<tir::IndexMap> queue;
+  {
+    std::lock_guard<std::mutex> lock(self->mutex_);
+    queue.swap(self->queue_);
+  }
+  if (queue.empty()) {
+    return expr;
+  }
+  // NOTE(review): the queue is a process-wide singleton, so two rewriter passes running
+  // concurrently would still interleave their entries; the locking only removes the data race.
+  return MetaScheduleFuncMutator(std::move(queue)).VisitExpr(expr);
+}
+
+namespace transform {
+
+/*!
+ * \brief Build the function pass that applies MetaScheduleLayoutRewriter to a function.
+ * \return The pass created by CreateFunctionPass under the name "MetaScheduleLayoutRewrite".
+ */
+Pass MetaScheduleLayoutRewrite() {
+  using PassFunc = runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>;
+  // The module and pass context arguments are part of the signature but are not used.
+  PassFunc rewrite = [](Function func, IRModule mod, PassContext ctx) -> Function {
+    MetaScheduleLayoutRewriter rewriter;
+    return Downcast<Function>(rewriter.Mutate(std::move(func)));
+  };
+  return CreateFunctionPass(rewrite, 3, "MetaScheduleLayoutRewrite", {"InferType"});
+}
+
+// Return `meta_schedule_original_shape` from `Attrs` when it holds an `AttrType` node;
+// otherwise execution falls through to the statement that follows the expansion.
+#define TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(Attrs, AttrType) \
+  if (const auto* p = Attrs.as<AttrType>()) {                                      \
+    return p->meta_schedule_original_shape;                                        \
+  }
+
+// Global function returning the `meta_schedule_original_shape` stored in the given attributes.
+// Each attribute type below is tried in turn; any other type ends in LOG(FATAL).
+TVM_REGISTER_GLOBAL("relay.attrs.get_meta_schedule_original_shape")
+    .set_body_typed([](const Attrs& attrs) -> Array<PrimExpr> {
+      TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, Conv2DAttrs);
+      TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, Conv2DWinogradAttrs);
+      TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, Conv3DAttrs);
+      TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, MatmulAttrs);
+      TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, DenseAttrs);
+      TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE(attrs, BatchMatmulAttrs);
+      LOG(FATAL) << "TypeError: Unknown attribute: " << attrs;
+      // NOTE(review): presumably never reached and only present so the non-void lambda has no
+      // path that falls off the end - confirm.
+      throw;
+    });
+// Registers the pass constructor under the global name "relay._transform.MetaScheduleLayoutRewrite".
+TVM_REGISTER_GLOBAL("relay._transform.MetaScheduleLayoutRewrite")
+    .set_body_typed(MetaScheduleLayoutRewrite);
+
+#undef TVM_RELAY_META_SCHEDULE_LAYOUT_REWRITE_GET_ORIGINAL_SHAPE
+
+}  // namespace transform
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/transforms/meta_schedule_layout_rewrite.h b/src/relay/transforms/meta_schedule_layout_rewrite.h
new file mode 100644
index 0000000000..f60df9b3e2
--- /dev/null
+++ b/src/relay/transforms/meta_schedule_layout_rewrite.h
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_META_SCHEDULE_LAYOUT_REWRITE_H_
+#define TVM_RELAY_TRANSFORMS_META_SCHEDULE_LAYOUT_REWRITE_H_
+
+#include <tvm/relay/expr_functor.h>
+#include <tvm/tir/index_map.h>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Expression mutator that overrides call-node visiting and exposes a static hook for
+ *  queueing index maps.
+ */
+class MetaScheduleLayoutRewriter : public ExprMutator {
+ public:
+  /*!
+   * \brief Visit a call node.
+   * \param n The call node to visit.
+   * \return The resulting expression.
+   */
+  Expr VisitExpr_(const CallNode* n) final;
+
+  /*!
+   * \brief Enqueue an index map.
+   * \param index_map The index map to enqueue.
+   */
+  static void LayoutQueuePush(const tir::IndexMap& index_map);
+};
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_TRANSFORMS_META_SCHEDULE_LAYOUT_REWRITE_H_