You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/27 15:53:39 UTC

[GitHub] [tvm] lhutton1 commented on a diff in pull request #11091: [AOT] Enable A-Normal Form in the AOT executor

lhutton1 commented on code in PR #11091:
URL: https://github.com/apache/tvm/pull/11091#discussion_r859966101


##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -126,7 +126,14 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
     for (const auto& param : func_node->params) {
       CreateStorage(param.get());
     }
-    GetStorage(func_node->body);
+    StorageInfo si = GetStorage(func_node->body);
+
+    // If the final expr could not be found it means it was let bound,

Review Comment:
   I recall trying this but moved away from the idea because it produced invalid outputs, though I agree it would tidy things up a bit. I intended to have another look over the past couple of days but ended up busy with other things before going away; I'll take another look when I return :)



##########
src/relay/backend/contrib/cmsisnn/relay_to_tir.cc:
##########
@@ -655,19 +655,61 @@ class RelayToTIRVisitor : public MixedModeMutator {
     return Call(new_global_var, call->args, call->attrs, call->type_args, call->span);
   }
 
-  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
-    if (const CallNode* call = post.as<CallNode>()) {
-      auto* func = call->op.as<FunctionNode>();
-      if (func == nullptr) {
-        return post;
+  Expr VisitExpr_(const LetNode* op) final {
+    auto pre_visit = [this](const LetNode* op) {
+      Expr var = this->VisitExpr(op->var);
+      Expr value = this->VisitExpr(op->value);
+      // outlineable function no longer needs let binding
+      if (this->CanOutlineExpr(value)) {
+        this->memo_[var] = value;
+      }
+    };
+    auto post_visit = [this](const LetNode* op) {
+      // Rely on the Memoizer to cache pre-visit values
+      Expr value = this->VisitExpr(op->value);
+      Expr body = this->VisitExpr(op->body);
+      auto expr = GetRef<Expr>(op);
+      // drop the let binding
+      if (this->CanOutlineExpr(value)) {
+        this->memo_[expr] = this->VisitExpr(op->body);
+      } else {
+        Var var = Downcast<Var>(this->VisitExpr(op->var));
+        if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
+          this->memo_[expr] = expr;
+        } else {
+          this->memo_[expr] = Let(var, value, body);
+        }
       }
+    };
+    ExpandANormalForm(op, pre_visit, post_visit);
+    return memo_[GetRef<Expr>(op)];
+  }
 
-      auto codegen_name = func->GetAttr<String>(attr::kCompiler);
-      if (codegen_name.defined() && codegen_name == "cmsis-nn") {
-        const CallNode* inner_call = func->body.as<CallNode>();
+  bool CanOutlineExpr(const Expr& expr) {
+    // TODO(@lhutton1): This behaviour is similar to the OutlineCompilerFunctions pass
+    // we could reuse this functionality by separating outlining and lowering in this
+    // pass.

Review Comment:
   It was just an observation that we could remove code duplication by using the `OutlineCompilerFunctions` pass, which I initially wrote for the microNPU but which is backend-agnostic :) That way, the workaround for the let bindings here is only needed once.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org