You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/16 16:33:45 UTC

[GitHub] [tvm] tkonolige commented on a change in pull request #7084: [TIR] Support Return in TIR

tkonolige commented on a change in pull request #7084:
URL: https://github.com/apache/tvm/pull/7084#discussion_r544443419



##########
File path: src/tir/transforms/make_packed_api.cc
##########
@@ -41,6 +41,56 @@
 namespace tvm {
 namespace tir {
 
+class ReturnRewriter : public StmtMutator {
+ public:
+  explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
+
+  Stmt VisitStmt_(const EvaluateNode* node) override {
+    Stmt ret = StmtMutator::VisitStmt_(node);
+    const EvaluateNode* eval = ret.as<EvaluateNode>();
+    CHECK(eval);

Review comment:
       Please change these to ICHECK because they are internal errors. Also, please add an error message, something along the lines of "Expected EvaluateNode, but got ...".

##########
File path: src/tir/transforms/make_packed_api.cc
##########
@@ -41,6 +41,56 @@
 namespace tvm {
 namespace tir {
 
+class ReturnRewriter : public StmtMutator {
+ public:
+  explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
+
+  Stmt VisitStmt_(const EvaluateNode* node) override {
+    Stmt ret = StmtMutator::VisitStmt_(node);
+    const EvaluateNode* eval = ret.as<EvaluateNode>();
+    CHECK(eval);
+    if (const CallNode* call = eval->value.as<CallNode>()) {
+      if (call->op.same_as(builtin::ret())) {
+        CHECK_EQ(call->args.size(), 1);

Review comment:
       Please use ICHECK_EQ here, with an error message.

##########
File path: src/tir/transforms/make_packed_api.cc
##########
@@ -41,6 +41,56 @@
 namespace tvm {
 namespace tir {
 
+class ReturnRewriter : public StmtMutator {
+ public:
+  explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {}
+
+  Stmt VisitStmt_(const EvaluateNode* node) override {
+    Stmt ret = StmtMutator::VisitStmt_(node);
+    const EvaluateNode* eval = ret.as<EvaluateNode>();
+    CHECK(eval);
+    if (const CallNode* call = eval->value.as<CallNode>()) {
+      if (call->op.same_as(builtin::ret())) {
+        CHECK_EQ(call->args.size(), 1);
+        ret = WriteToOut(call->args[0], ret_var_, ret_tcode_);
+      }
+    }
+    return ret;
+  }
+
+ private:
+  std::pair<int, PrimExpr> ConvertForFFI(PrimExpr val) {
+    DataType dtype = val.dtype();
+    if (dtype.is_int() || dtype.is_uint()) {
+      return {kTVMArgInt, Cast(DataType::Int(64), val)};
+    } else if (dtype.is_float()) {
+      return {kTVMArgFloat, Cast(DataType::Float(64), val)};
+    } else if (dtype.is_void()) {
+      return {kTVMNullptr, val};
+    } else {
+      LOG(FATAL) << "data type " << dtype << " not supported yet";

Review comment:
       Please add which datatypes are supported to this error message.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org