You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/10 21:19:34 UTC

[GitHub] [tvm] areusch commented on a change in pull request #7084: [WIP][TIR] Support Return in TIR

areusch commented on a change in pull request #7084:
URL: https://github.com/apache/tvm/pull/7084#discussion_r540496971



##########
File path: python/tvm/driver/build_module.py
##########
@@ -159,17 +159,30 @@ def lower(sch, args, name="main", binds=None, simple_mode=False):
     lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
     lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
 
+    is_tir_schedule = False
+
     # Phase 0
     if isinstance(sch, schedule.Schedule):
         mod = form_irmodule(sch, args, name, binds)
+    elif isinstance(sch, tvm.tir.PrimFunc):
+        func = sch.with_attr("global_symbol", name)
+        if pass_ctx.config.get("tir.restricted_func"):
+            func = func.with_attr("tir.noalias", True)
+        mod = tvm.IRModule({name: func})
+        is_tir_schedule = True
     else:
         mod = sch
 
     pass_list = lower_phase0
     # Phase 1
+    pass_list += [tvm.tir.transform.InjectPrefetch()]
+
+    if is_tir_schedule:
+        pass
+        # pass_list += [tvm.tir.transform.BufferFlatten()]

Review comment:
       Remove the commented-out code and invert the `if`.

##########
File path: src/tir/transforms/make_packed_api.cc
##########
@@ -222,6 +275,7 @@ namespace transform {
 
 Pass MakePackedAPI(int num_unpacked_args) {
   auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) {
+    LOG(INFO) << "Before Make Packed API:\n" << m;

Review comment:
       Remove this debug logging (`LOG(INFO)` of the module before the pass).

##########
File path: src/tir/transforms/make_packed_api.cc
##########
@@ -239,6 +293,7 @@ Pass MakePackedAPI(int num_unpacked_args) {
     for (const auto& pair : updates) {
       mptr->AddUnchecked(pair.first, pair.second);
     }
+    LOG(INFO) << "After Make Packed API:\n" << m;

Review comment:
       Remove this debug logging (`LOG(INFO)` of the module after the pass).

##########
File path: src/tir/transforms/make_packed_api.cc
##########
@@ -41,6 +41,58 @@
 namespace tvm {
 namespace tir {
 
+class ReturnRewriter : public StmtMutator {
+ public:
+  explicit ReturnRewriter(Var ret_var, Var ret_tcode)
+    : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
+
+  Stmt VisitStmt_(const EvaluateNode* node) override {
+    Stmt ret = StmtMutator::VisitStmt_(node);
+    const EvaluateNode* eval = ret.as<EvaluateNode>();
+    CHECK(eval);
+    if (const CallNode* call = eval->value.as<CallNode>()) {
+      if (call->op.same_as(builtin::myreturn())) {
+        CHECK_EQ(call->args.size(), 1);
+        ret = WriteToOut(call->args[0], ret_var_, ret_tcode_);
+      }
+    }
+    return ret;
+  }
+ private:
+  std::pair<int, PrimExpr> ConvertForFFI(PrimExpr val) {
+    DataType dtype = val.dtype();
+    if (dtype.is_int() || dtype.is_uint()) {
+      return {kTVMArgInt, Cast(DataType::Int(64), val)};
+    } else if (dtype.is_float()) {
+      return {kTVMArgFloat, Cast(DataType::Float(64), val)};
+    } else if (dtype.is_void()) {
+      return {kTVMNullptr, val};
+    } else {
+      LOG(FATAL) << "data type " << dtype << " not supported yet";

Review comment:
       I think we may need to support returning at least DLTensor values for AOT.

##########
File path: tests/python/unittest/test_tir_build.py
##########
@@ -0,0 +1,16 @@
+import tvm
+from tvm import tir
+
+def add():
+    a = tir.Var("a", "float32") 
+    b = tir.Var("b", "float32") 
+    c = a + b
+    c = tir.call_intrin("float32", "tir.myreturn", c) 
+    c = tir.Evaluate(c)
+    func = tir.PrimFunc([a, b], c)
+    mod = tvm.IRModule({'add': func})
+    func = tvm.build(mod['add'])
+    out = func(1.0, 2.0)
+    print(out)

Review comment:
       Add an assert here instead of printing (e.g. check that `out` equals 3.0).




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org