Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/10 01:25:00 UTC

[GitHub] [tvm] mbs-octoml opened a new pull request #9483: [DRAFT] Scratch pad for switching VM to use the LowerTE pass.

mbs-octoml opened a new pull request #9483:
URL: https://github.com/apache/tvm/pull/9483


   **Absolutely not ready for review.**
   
   This is an experiment to see how gross things get if we switch the VM away from lower-one-func-at-a-time to using the LowerTE pass.
   
   The answer looks to be: pretty gross, but doable.
   
   I don't want to go further without getting the calling conventions cleaned up, but this seems to handle:
    - ordinary code
    - device_copy (desperately needs to be cleaned up, probably using externs)
    - reshape
    - dynamic shape functions
   Lots of passes don't support running post-lowering without hackery, all of which can be improved.
   
   Still working through test_vm.py, let alone other unit tests.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759565993



##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -540,65 +505,60 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
     }
 
     for (auto output : output_tuple->fields) {
+      ICHECK(output->IsInstance<VarNode>()) << "output should be var, found:" << std::endl
+                                            << PrettyPrint(output);
       auto reg = var_register_map_.find(Downcast<Var>(output));
       ICHECK(reg != var_register_map_.end())
           << "internal error: all variables should be in the register mapping";
       argument_registers.push_back(reg->second);
     }
 
-    Target target;
-
-    // Which target should execute the function?
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      target = Target("ext_dev");
+    Index op_index;
+    auto itr = context_->primitive_map.find(global_var_node->name_hint);
+    if (itr == context_->primitive_map.end()) {
+      op_index = context_->primitive_map.size();
+      context_->primitive_map.emplace(global_var_node->name_hint, op_index);
     } else {
-      target = se_scope->target;
+      op_index = itr->second;
     }
-    ICHECK(target.defined()) << "No target for function:" << std::endl << PrettyPrint(func);
-
-    tec::CCacheKey key(func, target);
-    auto mangle_fn = [](String name) { return name; };
-    auto cfunc = context_->compiler->Lower(key, mangle_fn);  // <<<< one-func-at-a-time lowering
 
-    auto op_index = -1;
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      op_index = context_->cached_funcs.size();
-      context_->cached_funcs.push_back(cfunc);
-    } else {
-      // TODO(jroesch): support lowered funcs for multiple targets
-      ICHECK_EQ(cfunc->funcs->functions.size(), 1);
-      auto pfunc = Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
-      if (context_->seen_funcs.find(pfunc) == context_->seen_funcs.end()) {
-        op_index = context_->cached_funcs.size();
-        context_->cached_funcs.push_back(cfunc);
-        context_->seen_funcs[pfunc] = op_index;
-      } else {
-        op_index = context_->seen_funcs[pfunc];
-      }
-    }
-
-    // Extract functions attrs
-    op_attrs[op_index] = func->attrs->dict;
+    // Capture the dictionary of attributes from the original primitive function so that they

Review comment:
       I think that's a separate issue related to versioning of kernels etc. The change here is preserving the existing behaviour.
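For readers following the VMFunctionCompiler hunk above: the old code lowered each primitive on the spot and deduplicated by `PrimFunc` via `seen_funcs`, whereas the new code looks the callee's global name up in `primitive_map` and hands out the next dense index only on first sight. A minimal standalone sketch of that name-interning pattern follows, assuming plain `std::string` names and `std::size_t` indices in place of TVM's `GlobalVar` names and `Index`; `PrimitiveIndexer` is a hypothetical name, not TVM API.

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for context_->primitive_map: maps a primitive's
// global name to a dense op index, allocated in first-seen order.
class PrimitiveIndexer {
 public:
  std::size_t IndexFor(const std::string& name_hint) {
    auto itr = primitive_map_.find(name_hint);
    if (itr != primitive_map_.end()) {
      return itr->second;  // Seen before: reuse the existing index.
    }
    std::size_t op_index = primitive_map_.size();  // Next free index.
    primitive_map_.emplace(name_hint, op_index);
    return op_index;
  }

  std::size_t size() const { return primitive_map_.size(); }

 private:
  std::unordered_map<std::string, std::size_t> primitive_map_;
};
```

Keying on the global name works here because, after LowerTE, each lowered primitive already has a unique `GlobalVar` in the module, so structural deduplication is no longer the VM compiler's job.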







[GitHub] [tvm] jroesch commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

jroesch commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758750447



##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -1125,55 +1105,52 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
   }
 }
 
-void VMCompiler::PopulateGlobalMap() {
-  // First we populate global map.
-  size_t global_index = 0;
-  for (auto named_func : context_.module->functions) {
-    auto gvar = named_func.first;
-    context_.global_map.insert({gvar, global_index++});
+size_t VMCompiler::PopulateGlobalMap() {
+  // Allocate a VMFunction index for every Relay Function we could call.
+  // Excludes PrimFuncs and externs, which are managed by the primitive_map_.
+  for (const auto& kv : context_.module->functions) {
+    if (const auto* function_node = kv.second.as<FunctionNode>()) {
+      if (!function_node->GetAttr<String>(attr::kExternalSymbol)) {
+        context_.global_map.emplace(kv.first, context_.global_map.size());
+      }
+    }
   }
+  return context_.global_map.size();
 }
 
 void VMCompiler::Codegen() {
+  VLOG_CONTEXT << "VM Codegen";
   if (!context_.module.defined()) {
-    LOG(WARNING) << "Did you forget to call VMCompiler::Lower?";
-    return;
-  }
-  auto const& cached_funcs = context_.cached_funcs;
-  if (cached_funcs.size() == 0) {
+    LOG(WARNING) << "No compiled module to codegen from. Did you forget to call VMCompiler::Lower?";
     return;
   }
-  Map<Target, IRModule> funcs;
-
-  for (auto& cfunc : cached_funcs) {
-    Target target = cfunc->target;
-    // NOTE: because module, is mutable, we need to make an
-    // explicit copy of the IRModule.
-    IRModule mod = cfunc->funcs;
-    mod.CopyOnWrite();
-
-    if (target->kind->device_type == kDLExtDev) {
-      // Collect metadata in functions that are handled by external codegen.
-      auto name = cfunc->prim_fn_var->name_hint;
-      ICHECK(mod->ContainGlobalVar(name));
-      backend::UpdateConstants(mod->Lookup(name), &params_);
-    } else if (funcs.count(target) == 0) {
-      funcs.Set(target, mod);
-    } else {
-      funcs[target]->Update(mod);
-    }
-  }
 
-  auto ext_mods = context_.compiler->LowerExternalFunctions();
+  // At this point context_.module will contain only:

Review comment:
       Do we require the final build refactors to not require splitting here?
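The rewritten `PopulateGlobalMap` in the hunk above gives a VMFunction index only to Relay `Function`s that are not tagged `ExternalSymbol`; PrimFuncs and externs are left to the primitive map. A minimal sketch of that filter, assuming a hypothetical flattened `FuncEntry` record in place of TVM's `GlobalVar`/`BaseFunc` pairs. The real module iterates a `Map`, so index order there follows that map's iteration order; the sketch uses a `std::vector` only to keep the order deterministic.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical, simplified module entry (not TVM API).
struct FuncEntry {
  std::string name;
  bool is_relay_function;    // false for PrimFuncs
  bool has_external_symbol;  // true if tagged "ExternalSymbol"
};

// Only callable Relay Functions get a VMFunction index; PrimFuncs and
// externs are skipped because the primitive map manages them.
std::size_t PopulateGlobalMap(const std::vector<FuncEntry>& functions,
                              std::map<std::string, std::size_t>* global_map) {
  for (const auto& f : functions) {
    if (f.is_relay_function && !f.has_external_symbol) {
      // size() is evaluated before the insertion, giving the next dense index.
      global_map->emplace(f.name, global_map->size());
    }
  }
  return global_map->size();
}
```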







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759589257



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -224,23 +271,28 @@ class TECompilerImpl : public TECompilerNode {
     }
     cur_ccache_key_ = key;
 
-    // No need to lower external functions for now. We will invoke the external
-    // codegen tool once and lower all functions together.
-    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
-      auto ir_module = IRModule();
-      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
-      auto func_name = GetUniqueName(name_node.value(), &name_map_);
+    Optional<String> opt_compiler = key->source_func->GetAttr<String>(attr::kCompiler);
+    if (opt_compiler.defined()) {
+      // Don't compile now since we don't have anywhere to put the resulting runtime module.
+      // Instead place the original definition in the cache and wait for LowerExternalFunctions.
+      IRModule ir_module;
+      // Note that the source_func may already be bound to a global function in the module

Review comment:
       I moved the comment down a bit: it is really about the fact that we use opt_global_symbol.value() as *the* global name rather than attempting to rename it.







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759750028



##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -481,50 +484,12 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
     this->last_register_ = merge_register;
   }
 
-  void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {

Review comment:
       Yeah I deserve a badge for that one.







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759752626



##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -765,7 +765,8 @@ class DeviceCapturer : public ExprMutator {
 
   IRModule Capture() {
     VLOG_CONTEXT << "CaptureDevices";
-    IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map);
+    IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map,

Review comment:
       Yes! Thanks for reminding me, filed CORE-126.







[GitHub] [tvm] jroesch commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

jroesch commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758744300



##########
File path: src/relay/transforms/inline.cc
##########
@@ -54,22 +56,36 @@ class Inliner : ExprMutator {
       : cur_node_(cur_node), call_graph_(call_graph) {}
 
   Expr VisitExpr_(const CallNode* call_node) final {
-    Expr op = call_node->op;
-    const auto* gvn = op.as<GlobalVarNode>();
+    // We can work with calls in both Relay and call_lowered form.
+    Array<Expr> args;
+    Expr op;

Review comment:
       It might be good to insert checks below here in case the code is later refactored in a way that would leave these uninitialized.
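The suggestion above amounts to unpacking both call forms into one `(op, args)` pair and then checking the result explicitly, so a future refactor that forgets a branch fails loudly. A minimal sketch, assuming a hypothetical `SimpleCall` record in which a `call_lowered` call carries its real callee and arguments separately. In TVM the check would presumably be an `ICHECK(op.defined())`; the sketch throws `std::logic_error` instead so it stays self-contained.

```cpp
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical, simplified call (not TVM API). A call_lowered call wraps the
// real callee and its arguments; an ordinary Relay call holds them directly.
struct SimpleCall {
  std::string op;                         // callee name, or "call_lowered"
  std::vector<std::string> args;          // direct arguments
  std::string lowered_func;               // callee when op == "call_lowered"
  std::vector<std::string> lowered_args;  // arguments when op == "call_lowered"
};

struct CalleeAndArgs {
  std::string op;
  std::vector<std::string> args;
};

// Normalizes both call forms, then verifies the callee was actually set.
CalleeAndArgs ExtractCallee(const SimpleCall& call) {
  CalleeAndArgs result;
  if (call.op == "call_lowered") {
    result.op = call.lowered_func;
    result.args = call.lowered_args;
  } else {
    result.op = call.op;
    result.args = call.args;
  }
  if (result.op.empty()) {
    throw std::logic_error("callee was not set while unpacking call");
  }
  return result;
}
```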







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759577578



##########
File path: src/relay/analysis/call_graph.cc
##########
@@ -64,9 +67,19 @@ void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) {
   // post-order visitor will visit each AST node of the current function to
   // figure out the dependencies between functions.
   PostOrderVisit(func, [&](const Expr& expr) {
-    if (const GlobalVarNode* gvn = expr.as<GlobalVarNode>()) {
-      auto callee = GetRef<GlobalVar>(gvn);
-      cg_node->AddCalledGlobal(LookupGlobalVar(callee));
+    // TODO(mbs): Cleanup shapes functions.
+    if (const auto* call_node = expr.as<CallNode>()) {
+      CallLoweredProps props = GetCallLoweredProps(call_node);
+      if (props.lowered_func.defined() && props.attrs.metadata.count("prim_shape_fn_var")) {

Review comment:
       Agree, in general these free-form attrs should become first-class fields, but leaving that to a future refactor.
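The call-graph hunk above is cut off right after the `prim_shape_fn_var` test, so the following is only a sketch under an assumption: that the branch records the shape function as an additional callee alongside the lowered primitive. `CallSite` and its string-to-string `metadata` map are hypothetical stand-ins for `CallLoweredProps` and its free-form attrs. Promoting that attr to a first-class field, as agreed above, would replace the string-keyed lookup with a typed member.

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified call site (not TVM API): callee plus free-form
// metadata standing in for CallLoweredProps::attrs.metadata.
struct CallSite {
  std::string callee;
  std::map<std::string, std::string> metadata;
};

// Collects every global a function depends on. Assumption: a lowered call
// carrying "prim_shape_fn_var" also depends on that shape function.
std::set<std::string> CalledGlobals(const std::vector<CallSite>& calls) {
  std::set<std::string> called;
  for (const auto& call : calls) {
    called.insert(call.callee);
    auto it = call.metadata.find("prim_shape_fn_var");
    if (it != call.metadata.end()) {
      called.insert(it->second);
    }
  }
  return called;
}
```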







[GitHub] [tvm] jroesch commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

jroesch commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758632547



##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -215,24 +215,35 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
 
     if (memory_plan_.defined()) {
       // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
-      func_info =
-          relay::tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan_->expr_to_storage_info);
+      tec::TargetMap tec_target_map;
+      for (const auto& pair : targets_) {
+        tec_target_map.emplace(static_cast<DLDeviceType>(pair.first->value), pair.second);
+      }
+      func_info = relay::tec::UpdateMainWorkspaceSize(mod, tec_target_map,

Review comment:
       Do we have a tracking issue for flowing this into the actual lowering machinery? 
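The graph executor hunk above re-keys `targets_` into a `tec::TargetMap` by casting each integer key to `DLDeviceType` before calling `UpdateMainWorkspaceSize`. A minimal sketch of that re-keying, assuming a local `DeviceType` enum whose values follow DLPack's `kDLCPU = 1` and `kDLCUDA = 2`, and plain strings in place of `Target` objects; all names here are hypothetical.

```cpp
#include <map>
#include <string>

// Local stand-in for DLPack's DLDeviceType (kDLCPU = 1, kDLCUDA = 2).
enum DeviceType : int { kCPU = 1, kCUDA = 2 };

// Hypothetical simplified maps: integer-keyed input, enum-keyed output.
using IntTargetMap = std::map<int, std::string>;
using TypedTargetMap = std::map<DeviceType, std::string>;

// Re-keys the map by casting each integer key to the device-type enum.
TypedTargetMap ToTypedTargetMap(const IntTargetMap& targets) {
  TypedTargetMap result;
  for (const auto& pair : targets) {
    result.emplace(static_cast<DeviceType>(pair.first), pair.second);
  }
  return result;
}
```

The `static_cast` trusts that every key is a valid device type; nothing in the hunk validates that, which is one reason flowing the typed map into the lowering machinery directly (CORE-122) would be cleaner.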







[GitHub] [tvm] jroesch commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

jroesch commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758749457



##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -540,65 +505,60 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
     }
 
     for (auto output : output_tuple->fields) {
+      ICHECK(output->IsInstance<VarNode>()) << "output should be var, found:" << std::endl
+                                            << PrettyPrint(output);
       auto reg = var_register_map_.find(Downcast<Var>(output));
       ICHECK(reg != var_register_map_.end())
           << "internal error: all variables should be in the register mapping";
       argument_registers.push_back(reg->second);
     }
 
-    Target target;
-
-    // Which target should execute the function?
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      target = Target("ext_dev");
+    Index op_index;
+    auto itr = context_->primitive_map.find(global_var_node->name_hint);
+    if (itr == context_->primitive_map.end()) {
+      op_index = context_->primitive_map.size();
+      context_->primitive_map.emplace(global_var_node->name_hint, op_index);
     } else {
-      target = se_scope->target;
+      op_index = itr->second;
     }
-    ICHECK(target.defined()) << "No target for function:" << std::endl << PrettyPrint(func);
-
-    tec::CCacheKey key(func, target);
-    auto mangle_fn = [](String name) { return name; };
-    auto cfunc = context_->compiler->Lower(key, mangle_fn);  // <<<< one-func-at-a-time lowering
 
-    auto op_index = -1;
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      op_index = context_->cached_funcs.size();
-      context_->cached_funcs.push_back(cfunc);
-    } else {
-      // TODO(jroesch): support lowered funcs for multiple targets
-      ICHECK_EQ(cfunc->funcs->functions.size(), 1);
-      auto pfunc = Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
-      if (context_->seen_funcs.find(pfunc) == context_->seen_funcs.end()) {
-        op_index = context_->cached_funcs.size();
-        context_->cached_funcs.push_back(cfunc);
-        context_->seen_funcs[pfunc] = op_index;
-      } else {
-        op_index = context_->seen_funcs[pfunc];
-      }
-    }
-
-    // Extract functions attrs
-    op_attrs[op_index] = func->attrs->dict;
+    // Capture the dictionary of attributes from the original primitive function so that they

Review comment:
       This seems kind of leaky, do we have a better way to do this?







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759579098



##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -314,9 +316,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
   /*!
    * brief Create a function call
    * \param call_lowered_props The lowered function and the arguments to call it with
-   * \param call The call we got func and args from
+   * \param call_node The call we got func and args from
    */
-  void CreateFuncCall(CallLoweredProps call_lowered_props, Call call) {
+  void CreateFuncCall(CallLoweredProps call_lowered_props, const Expr& result_expr) {

Review comment:
       Thanks, apparently not a lint check :-(
   Done.







[GitHub] [tvm] mbs-octoml commented on pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

mbs-octoml commented on pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#issuecomment-983137065


   Thanks for slogging through this, @jroesch and @electriclilies. The next few should be more manageable. Or at least that's what I tell myself.





[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759752011



##########
File path: src/relay/op/memory/device_copy.cc
##########
@@ -117,15 +117,5 @@ DeviceCopyProps GetDeviceCopyProps(const Expr& expr) {
   return {};
 }
 
-DeviceCopyProps GetLoweredDeviceCopyProps(const CallLoweredProps& props) {

Review comment:
       :-)







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759748359



##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -471,10 +481,17 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
     With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
 
     std::unordered_map<te::Tensor, tir::Buffer> binds;
-    IRModule ir_module = tvm::LowerSchedule(schedule, all_args, func_name, binds);
-
+    IRModule lowered_module = tvm::LowerSchedule(schedule, all_args, func_name, binds);
+
+    IRModule fixed_lowered_module;
+    for (const auto& kv : lowered_module->functions) {

Review comment:
       Good catch -- forgot to comment that. Guess I was getting a bit desperate.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758649902



##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -215,24 +215,35 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
 
     if (memory_plan_.defined()) {
       // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
-      func_info =
-          relay::tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan_->expr_to_storage_info);
+      tec::TargetMap tec_target_map;
+      for (const auto& pair : targets_) {
+        tec_target_map.emplace(static_cast<DLDeviceType>(pair.first->value), pair.second);
+      }
+      func_info = relay::tec::UpdateMainWorkspaceSize(mod, tec_target_map,

Review comment:
       CORE-122







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759582385



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
     return mod;
   }
 
+  void AddExterns(IRModule module) {
+    // Everything tagged with "Compiler" has been compiled, so remove those definitions.
+    std::vector<GlobalVar> to_be_deleted;
+    for (const auto& kv : module->functions) {
+      if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
+        to_be_deleted.push_back(kv.first);
+      }
+    }
+    for (const auto& global_var : to_be_deleted) {
+      module->Remove(global_var);
+    }
+    // HOWEVER we still need a Relay definition to go with those now external functions, so
+    // retrieve them from the cache and mark them with "ExternalSymbol".

Review comment:
       The now-compiled globals have just been removed. This second loop is looking for the copies of those stored away inside the cache entry's 'funcs' IRModule. It was like this when I got here.
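The two loops in `AddExterns` follow a collect-then-remove pattern: names tagged "Compiler" are gathered first and erased afterwards, so the module is never mutated while it is being iterated, and then the copies kept in the cache are re-added with an "ExternalSymbol" marker, possibly under a different name. A minimal sketch, assuming hypothetical `Module` and `Func` types (a name-keyed map of attribute sets) rather than TVM's `IRModule`; it models only the attribute changes described in the hunk.

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified function record: just its attribute keys.
struct Func {
  std::set<std::string> attrs;
};
using Module = std::map<std::string, Func>;

void AddExterns(Module* module, const Module& cached_externs) {
  // Phase 1: collect names of already-compiled "Compiler" functions.
  std::vector<std::string> to_be_deleted;
  for (const auto& kv : *module) {
    if (kv.second.attrs.count("Compiler")) {
      to_be_deleted.push_back(kv.first);
    }
  }
  // Phase 2: erase them outside the iteration above.
  for (const auto& name : to_be_deleted) {
    module->erase(name);
  }
  // Phase 3: re-add the cached copies, marked "ExternalSymbol", so executors
  // can still see which (possibly renamed) names to retrieve.
  for (const auto& kv : cached_externs) {
    Func external = kv.second;
    external.attrs.insert("ExternalSymbol");
    (*module)[kv.first] = external;
  }
}
```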







[GitHub] [tvm] electriclilies commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

electriclilies commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759706150



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -481,206 +557,217 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
    */
   Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Array<Type> type_args, Span span,
                        Target target) {
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      // BYOC flow.
-      CCacheKey key = CCacheKey(func, target);
-      CachedFunc ext_func = compiler_->Lower(key, module_name_);
-      ICHECK(ext_func.defined()) << "Lowering returned undefined function for "
-                                 << ext_func->prim_fn_var->name_hint;
-
-      // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
-      Map<GlobalVar, tir::PrimFunc> prim_fns;
-      relay::Function func_with_metadata = func;
-      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var);
-      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
-      func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, ext_func->target);
-
-      // Provide a callback hook which allows one-level up code generators to
-      // act when we process a function.
-      this->process_fn_(func_with_metadata);
-
-      // TODO(mbs): Dynamic shapes?
-      // TODO(@mbs, electriclilies): Make extern functions explicit
-      return Call(ext_func->prim_fn_var, visited_args, Attrs(), type_args, span);
-
-    } else {
-      // Non-External Relay Function
-      CCacheKey key = CCacheKey(func, target);
-      CachedFunc lowered_func = compiler_->Lower(key, module_name_);
-
-      // Collect all the lowered functions produced for this primitive function.
-      Map<GlobalVar, tir::PrimFunc> prim_fns;
-      Array<GlobalVar> all_prim_fn_vars;
-      for (auto prim_fn : lowered_func->funcs->functions) {
-        CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
-        prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
-        all_prim_fn_vars.push_back(prim_fn.first);
-      }
-
-      // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
-      relay::Function func_with_metadata = func;
-      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var);
-      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
-      func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, lowered_func->target);
-
-      // Provide a callback hook which allows one-level up code generators to
-      // act when we process a function.
-      this->process_fn_(func_with_metadata);
-
-      auto call_lowered_attrs = make_object<CallLoweredAttrs>();
-      if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
-        call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
-      }
-
-      DeviceCopyProps props = GetDeviceCopyProps(func);
-      if (props.body.defined()) {
-        // Record the device copy source and destination scopes so the device planner can
-        // still follow along even after lowering.
-        call_lowered_attrs->metadata.Set("src_se_scope", props.src_se_scope);
-        call_lowered_attrs->metadata.Set("dst_se_scope", props.dst_se_scope);
+    CCacheKey key = CCacheKey(func, target);
+    CachedFunc cfunc = compiler_->Lower(key, module_name_);
+    ICHECK(cfunc.defined());
+
+    auto opt_compiler = func->GetAttr<String>(attr::kCompiler);
+
+    // Add some metadata on top of the *original function* and invoke the callback so it can
+    // be captured.
+    // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
+    Map<GlobalVar, tir::PrimFunc> prim_fns;
+    Array<GlobalVar> all_prim_fn_vars;
+    for (const auto& kv : cfunc->funcs->functions) {
+      if (opt_compiler) {
+        // We expect just the original func but with just the ExternalSymbol attribute signalling

Review comment:
       typo: signalling -> signaling

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -951,22 +1039,87 @@ void UpdateFunctionMetadata(BaseFunc func,
   function_metadata.Set(prim_fn_var.value()->name_hint, fi);
 }
 
-IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn) {
-  TECompiler compiler;
-
-  auto updated_module = LowerTensorExpr(module_name, compiler, process_fn)(module);
-
-  backend::UpdateAutoSchedulerOpWeights(compiler);
+IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn,
+                 SEScope host_se_scope) {
+  TECompiler compiler(module);
+
+  // TODO(mbs): This is all unnecessarily convoluted. Better would be to accumulate the rewritten
+  // module as we go (including rewritten Functions, lowered primitives, and runtime modules
+  // generated by external toolchains), and use a pair of maps over vars and global vars
+  // to global vars to remember which functions have already been lowered.
+
+  // Lower all the callees in module:
+  //  - Functions tagged with "Compiler" are unchanged (checked by CreateFunctionPass)
+  //  - Functions tagged with "Primitive" are unchanged (checked by LowerTensorExprMutator)
+  //  - Called functions tagged with "Compiler" are copied into the compiler cache with a fresh
+  //    GlobalVar, and calls updated (sticking with regular Relay Call).
+  //  - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, and calls updated
+  //    (using call_lowered convention).
+  IRModule updated_module = LowerTensorExpr(module_name, compiler, std::move(process_fn),
+                                            std::move(host_se_scope))(module);
+
+  // The Functions tagged with "Compiler" are now residing in the cache ready to be
+  // compiled by LowerExternalFunctions. However we still need a record of them in the
+  // IRModule so that the various executors can see which function names need to be
+  // retrieved. They may, however, have been renamed.
+  compiler->AddExterns(updated_module);
+
+  // Add the lowered functions.
+  IRModule lowered_module = compiler->GetLoweredFunctions();
+  VLOG(1) << "capturing " << lowered_module->functions.size() << " new lowered functions";
+  for (const auto& kv : lowered_module->functions) {
+    if (updated_module->ContainGlobalVar(kv.first->name_hint)) {
+      LOG(FATAL) << "duplicate bindings for '" << kv.first->name_hint
+                 << "'. Existing is:" << std::endl
+                 << PrettyPrint(updated_module->Lookup(kv.first->name_hint)) << std::endl
+                 << "while new is:" << std::endl
+                 << PrettyPrint(kv.second);
+    }
+    updated_module->Add(kv.first, kv.second);
+  }
 
-  // Copy the lowered functions into the return module
-  updated_module->Update(compiler->GetLoweredFunctions());
+  // Invoke external codegen for all Functions in the cache tagged with "Compiler", and
+  // annotate the module with the resulting runtime modules.
+  // TODO(mbs): runtime modules should be first class rather than attributes.
+  Array<runtime::Module> external_mods =
+      module->GetAttr<Array<runtime::Module>>("external_mods", Array<runtime::Module>()).value();
+  Array<runtime::Module> new_external_mods = compiler->LowerExternalFunctions();
+  VLOG(1) << "capturing " << external_mods.size() << " existing and " << new_external_mods.size()
+          << " new external modules";
+  for (const auto& mod : new_external_mods) {
+    external_mods.push_back(mod);  // copy-on-write.
+  }
 
-  // Annotate the module with C Device API context mapping, the external modules and function info
-  // this is until we have Target's annotated for the C Device API
+  // Annotate the module with C Device API context mapping (this is until we have Target's
+  // annotated for the C Device API)
   // TODO(Mousius) - Remove "device_contexts" as soon as we have the graph annotated properly with
   // Target's

Review comment:
       Again, Target's should be Targets or targets.
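
The `LowerTE` hunk above merges the freshly lowered functions into the rewritten module and refuses to overwrite an existing global binding. It then appends the new external runtime modules after any already recorded on the module. The following is a minimal C++ sketch of those two steps. It assumes a plain `std::map` from global name to printed function in place of `IRModule`, and throws an exception where the hunk uses `LOG(FATAL)`. `FunctionTable`, `AddLoweredFunctions` and `CombineExternalMods` are hypothetical names.

```cpp
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for an IRModule: global name -> printed function.
using FunctionTable = std::map<std::string, std::string>;

// Merge lowered functions into the rewritten module, refusing to silently
// replace an existing binding (the hunk uses LOG(FATAL) at this point).
void AddLoweredFunctions(FunctionTable& updated_module, const FunctionTable& lowered_module) {
  for (const auto& kv : lowered_module) {
    auto it = updated_module.find(kv.first);
    if (it != updated_module.end()) {
      throw std::runtime_error("duplicate bindings for '" + kv.first + "'. Existing is:\n" +
                               it->second + "\nwhile new is:\n" + kv.second);
    }
    updated_module.emplace(kv.first, kv.second);
  }
}

// Keep the modules already recorded under "external_mods" first, then append
// the ones produced by external codegen for this module.
std::vector<std::string> CombineExternalMods(std::vector<std::string> existing,
                                             const std::vector<std::string>& fresh) {
  existing.insert(existing.end(), fresh.begin(), fresh.end());
  return existing;
}
```

Failing loudly on a name clash, rather than letting the later binding win, is what makes accidental renaming bugs visible at compile time instead of at run time.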

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -951,22 +1039,87 @@ void UpdateFunctionMetadata(BaseFunc func,
   function_metadata.Set(prim_fn_var.value()->name_hint, fi);
 }
 
-IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn) {
-  TECompiler compiler;
-
-  auto updated_module = LowerTensorExpr(module_name, compiler, process_fn)(module);
-
-  backend::UpdateAutoSchedulerOpWeights(compiler);
+IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn,
+                 SEScope host_se_scope) {
+  TECompiler compiler(module);
+
+  // TODO(mbs): This is all unnecessarily convoluted. Better would be to accumulate the rewritten
+  // module as we go (including rewritten Functions, lowered primitives, and runtime modules
+  // generated by external toolchains), and use a pair of maps over vars and global vars
+  // to global vars to remember which functions have already been lowered.
+
+  // Lower all the callees in module:
+  //  - Functions tagged with "Compiler" are unchanged (checked by CreateFunctionPass)
+  //  - Functions tagged with "Primitive" are unchanged (checked by LowerTensorExprMutator)
+  //  - Called functions tagged with "Compiler" are copied into the compiler cache with a fresh
+  //    GlobalVar, and calls updated (sticking with regular Relay Call).
+  //  - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, and calls updated
+  //    (using call_lowered convention).
+  IRModule updated_module = LowerTensorExpr(module_name, compiler, std::move(process_fn),
+                                            std::move(host_se_scope))(module);
+
+  // The Functions tagged with "Compiler" are now residing in the cache ready to be
+  // compiled by LowerExternalFunctions. However we still need a record of them in the
+  // IRModule so that the various executors can see which function names need to be
+  // retrieved. They may, however, have been renamed.
+  compiler->AddExterns(updated_module);
+
+  // Add the lowered functions.
+  IRModule lowered_module = compiler->GetLoweredFunctions();
+  VLOG(1) << "capturing " << lowered_module->functions.size() << " new lowered functions";
+  for (const auto& kv : lowered_module->functions) {
+    if (updated_module->ContainGlobalVar(kv.first->name_hint)) {
+      LOG(FATAL) << "duplicate bindings for '" << kv.first->name_hint
+                 << "'. Existing is:" << std::endl
+                 << PrettyPrint(updated_module->Lookup(kv.first->name_hint)) << std::endl
+                 << "while new is:" << std::endl
+                 << PrettyPrint(kv.second);
+    }
+    updated_module->Add(kv.first, kv.second);
+  }
 
-  // Copy the lowered functions into the return module
-  updated_module->Update(compiler->GetLoweredFunctions());
+  // Invoke external codegen for all Functions in the cache tagged with "Compiler", and
+  // annotate the module with the resulting runtime modules.
+  // TODO(mbs): runtime modules should be first class rather than attributes.
+  Array<runtime::Module> external_mods =
+      module->GetAttr<Array<runtime::Module>>("external_mods", Array<runtime::Module>()).value();
+  Array<runtime::Module> new_external_mods = compiler->LowerExternalFunctions();
+  VLOG(1) << "capturing " << external_mods.size() << " existing and " << new_external_mods.size()
+          << " new external modules";
+  for (const auto& mod : new_external_mods) {
+    external_mods.push_back(mod);  // copy-on-write.
+  }
 
-  // Annotate the module with C Device API context mapping, the external modules and function info
-  // this is until we have Target's annotated for the C Device API
+  // Annotate the module with C Device API context mapping (this is until we have Target's

Review comment:
       typo: Target's -> Targets

##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -427,26 +414,49 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
       candidate_name = truncated_name.str();
     }
 
-    // Set all the inputs correctly.
+    // Set all the inputs correctly, and accumulate their types from the p.o.v. of the
+    // shape function rather than the primitive it is derived for.
     Array<te::Tensor> inputs;
+    Array<Type> shape_function_arg_types;
     for (auto param : prim_func->params) {
       int state = param_states_[param];
       shape_func_param_states.push_back(IntImm(DataType::Int(32), state));
       if (state & kNeedInputData) {
+        // Pass the primitive arguments directly (though in flattened form and on the host)
         for (auto t : param_data_[param]) {
           inputs.push_back(t);
+          shape_function_arg_types.push_back(TensorType(t->GetShape(), t->GetDataType()));
         }
       }
       if (state & kNeedInputShape) {
+        // Pass the shapes of the primitive arguments (also on the host)
         for (auto t : param_shapes_[param]) {
           inputs.push_back(t);
+          shape_function_arg_types.push_back(TensorType(t->GetShape(), t->GetDataType()));
         }
       }
     }
 
+    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
+    // no  other GlobalVar ctors should appear inside the lowering machinery.
     auto func_name = renamer(candidate_name);
     auto prim_fn_gvar = GlobalVar(func_name);
-    prim_fn_gvar->checked_type_ = prim_func->checked_type();
+
+    // Gather the result types, again from the p.o.v. of the shape function rather than
+    // the primitive it is derived for.
+    Array<Type> shape_function_res_types;
+    for (const auto& t : outputs) {
+      shape_function_res_types.push_back(TensorType(t->GetShape(), t->GetDataType()));
+    }
+
+    // Assign the shape function it's true type.

Review comment:
       typo: it's -> its
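
The shape-function hunk builds its argument list from each parameter's state. Data tensors are passed when `kNeedInputData` is set, and shape tensors when `kNeedInputShape` is set. Each passed tensor contributes a type taken from that tensor itself, so the resulting signature describes the shape function rather than the primitive it was derived from. The following is a simplified C++ sketch of that accumulation. It assumes a hypothetical `TensorDesc` in place of `te::Tensor`, and assumes flag values of 1 and 2 for the two states.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for te::Tensor: just a static shape and a dtype name.
struct TensorDesc {
  std::vector<int64_t> shape;
  std::string dtype;
  bool operator==(const TensorDesc& other) const {
    return shape == other.shape && dtype == other.dtype;
  }
};

// Assumed bit values for the parameter states used in the hunk.
enum ParamState : int { kNoNeed = 0, kNeedInputData = 1, kNeedInputShape = 2 };

// One primitive parameter: its state plus the flattened data and shape tensors
// (analogous to param_data_[param] and param_shapes_[param]).
struct ParamInfo {
  int state;
  std::vector<TensorDesc> data;
  std::vector<TensorDesc> shapes;
};

// Record each parameter's state and collect the shape function's argument
// types in parameter order, data before shape when both bits are set.
std::vector<TensorDesc> ShapeFunctionArgTypes(const std::vector<ParamInfo>& params,
                                              std::vector<int>* param_states) {
  std::vector<TensorDesc> arg_types;
  for (const auto& p : params) {
    param_states->push_back(p.state);
    if (p.state & kNeedInputData) {
      arg_types.insert(arg_types.end(), p.data.begin(), p.data.end());
    }
    if (p.state & kNeedInputShape) {
      arg_types.insert(arg_types.end(), p.shapes.begin(), p.shapes.end());
    }
  }
  return arg_types;
}
```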

##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -471,10 +481,17 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
     With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
 
     std::unordered_map<te::Tensor, tir::Buffer> binds;
-    IRModule ir_module = tvm::LowerSchedule(schedule, all_args, func_name, binds);
-
+    IRModule lowered_module = tvm::LowerSchedule(schedule, all_args, func_name, binds);
+
+    IRModule fixed_lowered_module;
+    for (const auto& kv : lowered_module->functions) {

Review comment:
       Why do we need to fix the lowered module?

##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -1051,6 +995,26 @@ transform::Sequential MemoryOpt(const SEScope& cpu_se_scope) {
   return transform::Sequential(std::move(pass_seqs));
 }
 
+transform::Sequential VMCompiler::LowerOperators(const SEScope& host_se_scope) {
+  Array<Pass> pass_seqs;

Review comment:
       It might be good to rename this to something like "fuse and lower" (anything that mentions fusion); in the places where you use it, you also have a note saying that it fuses as well as lowers.

##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -481,50 +484,12 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
     this->last_register_ = merge_register;
   }
 
-  void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {

Review comment:
       Cool that you got rid of this!!

##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -1102,7 +1066,24 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
   pass_seqs.push_back(transform::ToANormalForm());
   pass_seqs.push_back(transform::InferType());
   pass_seqs.push_back(transform::LambdaLift());
-  pass_seqs.push_back(transform::InlinePrimitives());
+
+  // Eliminate dead-code before we lower. We don't track the purity of PrimFuncs, thus after
+  // lowering all calls to lowered functions will be kept.
+  pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
+  pass_seqs.push_back(transform::LabelOps());
+
+  // Lower all function's annotated as "primitive" by FuseOps.

Review comment:
       typo: function's -> functions
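
The comment in the `OptimizeModuleImpl` hunk explains why dead-code elimination now runs before lowering. Purity is not tracked for PrimFuncs, so once a call targets a lowered function it can no longer be dropped. The following is a minimal C++ sketch of a backward liveness sweep over A-normal-form let-bindings that makes this conservative rule explicit. `Binding`, `calls_lowered_prim` and `EliminateDeadBindings` are hypothetical names; this is not TVM's `DeadCodeElimination` implementation.

```cpp
#include <set>
#include <string>
#include <vector>

// A let-binding in A-normal form (hypothetical, simplified).
struct Binding {
  std::string var;
  std::vector<std::string> uses;  // variables referenced by the bound value
  bool calls_lowered_prim;        // value is a call to an already-lowered PrimFunc
};

// Walk the bindings backwards. A binding survives if its variable is live, or
// if its value must be treated as effectful: calls into lowered PrimFuncs are
// kept because nothing records whether a PrimFunc is pure.
std::vector<Binding> EliminateDeadBindings(const std::vector<Binding>& body,
                                           const std::string& result) {
  std::set<std::string> live = {result};
  std::vector<Binding> kept;
  for (auto it = body.rbegin(); it != body.rend(); ++it) {
    if (live.count(it->var) || it->calls_lowered_prim) {
      live.insert(it->uses.begin(), it->uses.end());
      kept.insert(kept.begin(), *it);
    }
  }
  return kept;
}
```

Running the sweep while the unused call is still a plain Relay call lets it be removed. After lowering, the same binding would be kept.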

##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -395,7 +400,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
   //
   // The result will be the return type of the operator.
   Type PrimitiveCall(const FuncTypeNode* op, Array<Type> arg_types, const Attrs& attrs,
-                     const Span& span) {
+                     const Span& span, const Expr& expr) {

Review comment:
       It looks like you don't use expr in this function at all? Also, can you rename op to func_type_node :-)

##########
File path: src/relay/transforms/device_planner.cc
##########
@@ -765,7 +765,8 @@ class DeviceCapturer : public ExprMutator {
 
   IRModule Capture() {
     VLOG_CONTEXT << "CaptureDevices";
-    IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map);
+    IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map,

Review comment:
       Hmm, maybe we should add a WithFields for IRModule. 
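
The `DeviceCapturer` hunk rebuilds an `IRModule` by re-passing every constructor argument, just to start from an empty function map. The reviewer's `WithFields` suggestion would avoid that. The following sketch shows the copy-with-overrides pattern on a hypothetical, simplified `Module` struct (not TVM's `IRModule`), using `std::optional` to mark which fields to replace. `WithFields` here is illustrative only.

```cpp
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified module record; the real IRModule carries more state.
struct Module {
  std::map<std::string, std::string> functions;
  std::vector<std::string> type_definitions;
  std::vector<std::string> imports;
  std::map<std::string, std::string> attrs;
};

// Return a copy of `mod` with only the supplied fields replaced, so callers
// do not have to re-list every constructor argument to change one field.
Module WithFields(Module mod,
                  std::optional<std::map<std::string, std::string>> functions = std::nullopt,
                  std::optional<std::vector<std::string>> type_definitions = std::nullopt,
                  std::optional<std::vector<std::string>> imports = std::nullopt,
                  std::optional<std::map<std::string, std::string>> attrs = std::nullopt) {
  if (functions) mod.functions = std::move(*functions);
  if (type_definitions) mod.type_definitions = std::move(*type_definitions);
  if (imports) mod.imports = std::move(*imports);
  if (attrs) mod.attrs = std::move(*attrs);
  return mod;
}
```

With this design, passing a bare `{}` for `functions` value-initializes the `std::optional` to empty and leaves the field unchanged. An explicitly empty map, such as `std::map<std::string, std::string>{}`, has to be spelled out.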

##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -481,50 +484,12 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
     this->last_register_ = merge_register;
   }
 
-  void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
-    // Lower shape function
-    tec::CCacheKey key(func, host_se_scope_->target);
-    auto cfunc = context_->compiler->LowerShapeFunc(key);
-    int op_index = -1;
-    // pick the only function inside the context
-    ICHECK_EQ(cfunc->funcs->functions.size(), 1);
-    auto pfunc = Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
-    if (context_->seen_funcs.count(pfunc) == 0) {
-      op_index = context_->cached_funcs.size();
-      context_->cached_funcs.push_back(cfunc);
-      context_->seen_funcs[pfunc] = op_index;
-    } else {
-      op_index = context_->seen_funcs[pfunc];
-    }
-
-    // Prepare input and output registers
+  void EmitInvokeTVMOp(const Expr& func, const Expr& inputs, const Expr& outputs,

Review comment:
       Why don't we need to pass SEScope in here anymore?

##########
File path: src/relay/op/memory/device_copy.cc
##########
@@ -117,15 +117,5 @@ DeviceCopyProps GetDeviceCopyProps(const Expr& expr) {
   return {};
 }
 
-DeviceCopyProps GetLoweredDeviceCopyProps(const CallLoweredProps& props) {

Review comment:
       Glad to see this simplification!

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -481,206 +557,217 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
    */
   Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Array<Type> type_args, Span span,
                        Target target) {
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      // BYOC flow.
-      CCacheKey key = CCacheKey(func, target);
-      CachedFunc ext_func = compiler_->Lower(key, module_name_);
-      ICHECK(ext_func.defined()) << "Lowering returned undefined function for "
-                                 << ext_func->prim_fn_var->name_hint;
-
-      // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
-      Map<GlobalVar, tir::PrimFunc> prim_fns;
-      relay::Function func_with_metadata = func;
-      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var);
-      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
-      func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, ext_func->target);
-
-      // Provide a callback hook which allows one-level up code generators to
-      // act when we process a function.
-      this->process_fn_(func_with_metadata);
-
-      // TODO(mbs): Dynamic shapes?
-      // TODO(@mbs, electriclilies): Make extern functions explicit
-      return Call(ext_func->prim_fn_var, visited_args, Attrs(), type_args, span);
-
-    } else {
-      // Non-External Relay Function
-      CCacheKey key = CCacheKey(func, target);
-      CachedFunc lowered_func = compiler_->Lower(key, module_name_);
-
-      // Collect all the lowered functions produced for this primitive function.
-      Map<GlobalVar, tir::PrimFunc> prim_fns;
-      Array<GlobalVar> all_prim_fn_vars;
-      for (auto prim_fn : lowered_func->funcs->functions) {
-        CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
-        prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
-        all_prim_fn_vars.push_back(prim_fn.first);
-      }
-
-      // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
-      relay::Function func_with_metadata = func;
-      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var);
-      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
-      func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, lowered_func->target);
-
-      // Provide a callback hook which allows one-level up code generators to
-      // act when we process a function.
-      this->process_fn_(func_with_metadata);
-
-      auto call_lowered_attrs = make_object<CallLoweredAttrs>();
-      if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
-        call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
-      }
-
-      DeviceCopyProps props = GetDeviceCopyProps(func);
-      if (props.body.defined()) {
-        // Record the device copy source and destination scopes so the device planner can
-        // still follow along even after lowering.
-        call_lowered_attrs->metadata.Set("src_se_scope", props.src_se_scope);
-        call_lowered_attrs->metadata.Set("dst_se_scope", props.dst_se_scope);
+    CCacheKey key = CCacheKey(func, target);
+    CachedFunc cfunc = compiler_->Lower(key, module_name_);
+    ICHECK(cfunc.defined());
+
+    auto opt_compiler = func->GetAttr<String>(attr::kCompiler);
+
+    // Add some metadata on top of the *original function* and invoke the callback so it can
+    // be captured.
+    // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
+    Map<GlobalVar, tir::PrimFunc> prim_fns;
+    Array<GlobalVar> all_prim_fn_vars;
+    for (const auto& kv : cfunc->funcs->functions) {
+      if (opt_compiler) {
+        // We expect just the original func but with just the ExternalSymbol attribute signalling
+        // the function (will be) compiled externally.
+        ICHECK(kv.second.as<FunctionNode>())
+            << PrettyPrint(kv.first) << " must be bound to an (external) Function";
+      } else {
+        // We expect one or more PrimFuncs, one of which corresponds to 'the' lowered primitive
+        // (and the rest in support of that via tir::Calls).
+        ICHECK(kv.second.as<tir::PrimFuncNode>())
+            << PrettyPrint(kv.first) << " must be bound to a PrimFunc";
+        prim_fns.Set(kv.first, Downcast<tir::PrimFunc>(kv.second));
+        all_prim_fn_vars.push_back(kv.first);
       }
+    }
+    Function func_with_metadata = func;
+    func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", cfunc->prim_fn_var);
+    func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
+    func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, cfunc->target);
+    this->process_fn_(func_with_metadata);
+
+    auto call_lowered_attrs = make_object<CallLoweredAttrs>();
+
+    // Non-External Relay Function
+    // TODO(mbs): "reshape" cleanup.
+    if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) {
+      call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
+    }
 
-      call_lowered_attrs->metadata.Set("relay_attrs", func->attrs);
-      call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
-
-      if (IsDynamic(func->ret_type)) {
-        // Also lower the dynamic shape function.
-        // Shape function keys use the underlying primitive function as their 'function',
-        // but the generic 'cpu' target as the target since all shape functions run
-        // on the host cpu irrespective of where the primitive runs.
-        // TODO(mbs): Cleanup target handling.
-        Target shape_target("llvm");
-        CCacheKey shape_key(func, shape_target);
-        CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
-        // Capture the shape function's global var and parameters 'states' in call
-        // annotations so calling convention can be recovered.
-        // TODO(mbs): Capture all this as part of a 'call into TIR' construct once available.
-        // The way the shape function calling convention is derived and passed to call sites
-        // via the 'parameter states' could be improved.
-        call_lowered_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var);
-        call_lowered_attrs->metadata.Set("prim_shape_fn_states",
-                                         lowered_shape_func->shape_func_param_states);
-        call_lowered_attrs->metadata.Set(
-            "prim_shape_fn_num_inputs",
-            Integer(static_cast<int>(lowered_shape_func->inputs.size())));
-        call_lowered_attrs->metadata.Set(
-            "prim_shape_fn_num_outputs",
-            Integer(static_cast<int>(lowered_shape_func->outputs.size())));
-        Array<GlobalVar> all_prim_shape_fn_vars;
-        for (auto prim_shape_fn : lowered_shape_func->funcs->functions) {
-          CHECK(prim_shape_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
-          all_prim_shape_fn_vars.push_back(prim_shape_fn.first);
-        }
-        call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars);
+    call_lowered_attrs->metadata.Set("relay_attrs", func->attrs);
+    call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
+
+    if (IsDynamic(func->ret_type)) {
+      // Also lower the companion dynamic shape function.
+      // Shape function keys use the underlying primitive function as their 'function',
+      // but the generic 'cpu' target as the target since all shape functions run
+      // on the host cpu irrespective of where the primitive runs.
+      CCacheKey shape_key(func, host_se_scope_->target);
+      CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
+
+      // Capture the shape function's global var and parameters 'states' in call
+      // annotations so calling convention can be recovered.
+      // TODO(mbs): Shape cleanup.
+      call_lowered_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var);
+      call_lowered_attrs->metadata.Set("prim_shape_fn_states",
+                                       lowered_shape_func->shape_func_param_states);
+      call_lowered_attrs->metadata.Set(
+          "prim_shape_fn_num_inputs", Integer(static_cast<int>(lowered_shape_func->inputs.size())));
+      call_lowered_attrs->metadata.Set(
+          "prim_shape_fn_num_outputs",
+          Integer(static_cast<int>(lowered_shape_func->outputs.size())));
+      Array<GlobalVar> all_prim_shape_fn_vars;
+      for (const auto& kv : lowered_shape_func->funcs->functions) {
+        CHECK(kv.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
+        all_prim_shape_fn_vars.push_back(kv.first);
       }
-      return CallLowered(lowered_func->prim_fn_var, visited_args, Attrs(call_lowered_attrs),
-                         type_args, span);
+      call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars);
     }
+
+    return CallLowered(cfunc->prim_fn_var, std::move(visited_args), Attrs(call_lowered_attrs),
+                       type_args, std::move(span));
   }
 
   std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final {
     Var new_var = Downcast<Var>(Mutate(var));
     Expr new_value = Mutate(value);
     BaseFunc prim_func = ResolveToPrimitive(new_value);
 
-    if (prim_func.defined() && !prim_func->IsInstance<tir::PrimFuncNode>()) {
-      // Remember let var is bound to (possibly indirectly) a non-tir primitive.
-      Function func = Downcast<Function>(prim_func);
-      primitive_functions_.emplace(var, func);
+    if (prim_func.defined()) {
+      // Remember let var is bound (possibly indirectly) to a primitive function.
+      primitive_functions_.emplace(var.get(), prim_func);
     }
     return {new_var, new_value};
   }
 
   Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) final {
     BaseFunc prim_func = ResolveToPrimitive(post_let_node->value);
-    if (prim_func.defined() && !prim_func->IsInstance<tir::PrimFuncNode>()) {
+    if (prim_func.defined()) {
       // Leaving let var scope
-      primitive_functions_.erase(pre_let_node->var);
+      primitive_functions_.erase(pre_let_node->var.get());
     }
     return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node);
   }
 
   Expr DeviceAwareVisitExpr_(const FunctionNode* function_node) override {
-    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
-      // Nothing to lower inside primitive functions.
+    if (function_node->HasNonzeroAttr(attr::kPrimitive) ||
+        function_node->GetAttr<String>(attr::kExternalSymbol)) {
+      // Nothing to lower inside primitive/external functions.
       return GetRef<Function>(function_node);
     } else {
       return DeviceAwareExprMutator::DeviceAwareVisitExpr_(function_node);
     }
   }
 
   Expr DeviceAwareVisitExpr_(const CallNode* call_node) override {
-    // Passes before lowering might insert a call_lowered to call a function that has already
-    // been lowered. Therefore we might see call_lowered ops here, but we don't need to do anything
-    // because ResolveToPrimitive returns null for all calls where the call_node->op is an OpNode
-    Call call = GetRef<Call>(call_node);
-
-    // Look for (indirect) calls to primitives.
-    BaseFunc prim_func = ResolveToPrimitive(call_node->op);
-    if (!prim_func.defined()) {
-      // Not a call_node to a primitive function.
-      if (const FunctionNode* fn = call_node->op.as<FunctionNode>()) {
-        this->process_fn_(GetRef<Function>(fn));
+    // We can see five forms of calls:
+    //  1. A 'normal' Relay call to a Function with the "primitive" attribute. We will need
+    //     to lower that to a global PrimFunc and rewrite the call to:
+    //       call_lowered(@new_global, (arg1, ..., argn), <attributes>)
+    //     However there are a few special forms which are excluded from this treatment, see
+    //     below.
+    //  2. A 'normal' Relay call to a Function with the "compiler" attribute. We will need
+    //     to invoke the appropriate BYOC toolchain function to yield a runtime module and
+    //     rewrite the call to the same form as above.
+    //  3. A 'normal' Relay call to a PrimFunc which has already been supplied via a global
+    //     definition. We rewrite to use the call_lowered form, but otherwise nothing else
+    //     needs to be done.
+    //  4. A 'normal' Relay call to a Relay Function without any special attribute. These
+    //     calls are not changed.
+    //  5. A call_lowered call from an earlier invocation of this pass.
+    // Note that ResolveToPrimitive will yield non-null only for cases 1-3.
+
+    // Look for (possibly indirect) calls to primitives.
+    BaseFunc primitive_func = ResolveToPrimitive(call_node->op);
+    if (!primitive_func.defined()) {
+      // Not a call to a primitive function we need to rewrite.
+      if (const auto* function_node = call_node->op.as<FunctionNode>()) {
+        process_fn_(GetRef<Function>(function_node));
       }
-      return ExprMutator::VisitExpr_(call_node);
+      return DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
     }
 
-    // Similarly transform arguments.
-    Array<Expr> visited_args;
+    // Prepare the arguments.
+    Array<Expr> new_args;
     for (const auto& arg : call_node->args) {
-      visited_args.push_back(VisitExpr(arg));
+      new_args.push_back(VisitExpr(arg));
     }
 
-    // Already lowered by other means so we don't need to mutate
+    // Special case: device_copies are left as calls to primitive operators
+    // (thus undoing FuseOps) so that each backend can handle them directly.
+    // TODO(mbs): device_copy cleanup. Would be better for FuseOps to just leave device_copy alone.
+    if (const auto* function_node = primitive_func.as<FunctionNode>()) {
+      DeviceCopyProps device_copy_props = GetDeviceCopyProps(function_node->body);
+      if (device_copy_props.body.defined()) {
+        ICHECK_EQ(new_args.size(), 1);
+        return DeviceCopy(new_args[0], device_copy_props.src_se_scope,
+                          device_copy_props.dst_se_scope);
+      }
+    }
+
+    // Special case: If already lowered by other means then so we don't need to mutate
     // the call but we do need to mutate the arguments
-    if (prim_func->IsInstance<tir::PrimFuncNode>()) {
+    if (const auto* prim_func_node = primitive_func.as<tir::PrimFuncNode>()) {
       // Function should already be Target annotated by this point
       // but the TE Compiler metadata is still needed for the callback
       // TODO(Mousius) - Robustify this to not assume we're in the GlobalVar for Target Hooks
       GlobalVar prim_func_var = Downcast<GlobalVar>(call_node->op);
-      tir::PrimFunc downcast_prim_func = Downcast<tir::PrimFunc>(prim_func);
+      tir::PrimFunc prim_func = GetRef<tir::PrimFunc>(prim_func_node);
 
-      Map<GlobalVar, tir::PrimFunc> prim_fns = {{prim_func_var, downcast_prim_func}};
-      tir::PrimFunc func_with_metadata =
-          WithAttrs(downcast_prim_func, {
-                                            {"prim_fn_var", prim_func_var},
-                                            {"prim_funcs", prim_fns},
-                                        });
+      Map<GlobalVar, tir::PrimFunc> prim_fns = {{prim_func_var, prim_func}};
+      tir::PrimFunc func_with_metadata = WithAttrs(prim_func, {
+                                                                  {"prim_fn_var", prim_func_var},
+                                                                  {"prim_funcs", prim_fns},
+                                                              });
+
+      ICHECK(!IsDynamic(call_node->checked_type()));
+      auto call_lowered_attrs = make_object<CallLoweredAttrs>();
+      call_lowered_attrs->metadata.Set("relay_attrs", primitive_func->attrs);
 
-      this->process_fn_(func_with_metadata);
-      return Call(call_node->op, visited_args, call_node->attrs);
+      process_fn_(func_with_metadata);
+      return CallLowered(call_node->op, std::move(new_args), Attrs(std::move(call_lowered_attrs)),
+                         call_node->type_args, call_node->span);
     }
 
+    // Typical case: call to fused primitive Relay Function.
     // Find the desired target device.
     Target target;
-    if (prim_func->GetAttr<String>(attr::kCompiler).defined()) {
+    if (primitive_func->GetAttr<String>(attr::kCompiler).defined()) {
       // The generic 'external device' target.
+      // TODO(mbs): Retire once replaced unified BYOC compiler and target macihnery.

Review comment:
       typo: macihnery -> machinery
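
The comment in `DeviceAwareVisitExpr_` enumerates five call forms. The following small C++ sketch maps a simplified description of the callee onto those same numbers. `CalleeKind`, `CalleeInfo` and `ClassifyCall` are hypothetical. The sketch checks the "Compiler" attribute before "Primitive", mirroring the target selection in the hunk. It ignores the `device_copy` special case, which the hunk turns back into a plain `DeviceCopy` before any lowering.

```cpp
// Hypothetical summary of what a call's callee resolves to.
enum class CalleeKind { kRelayFunction, kPrimFunc, kOperator };

struct CalleeInfo {
  CalleeKind kind;
  bool is_primitive = false;     // Relay Function with the "Primitive" attribute
  bool has_compiler = false;     // Relay Function with the "Compiler" attribute
  bool is_call_lowered = false;  // the operator is call_lowered
};

// Returns the form number (1-5) used in the comment, or 0 for any other
// operator call, which the comment does not cover.
int ClassifyCall(const CalleeInfo& callee) {
  switch (callee.kind) {
    case CalleeKind::kPrimFunc:
      return 3;  // already-lowered global PrimFunc: only rewrite to call_lowered
    case CalleeKind::kRelayFunction:
      if (callee.has_compiler) return 2;  // BYOC: external codegen, then call_lowered
      if (callee.is_primitive) return 1;  // fused primitive: lower to PrimFunc(s)
      return 4;                           // ordinary Relay function: left unchanged
    case CalleeKind::kOperator:
      break;
  }
  return callee.is_call_lowered ? 5 : 0;  // 5: output of an earlier run of the pass
}
```

According to the comment, `ResolveToPrimitive` reports only forms 1 to 3. Forms 4 and 5 fall through to the ordinary `DeviceAwareExprMutator` visit.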







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759751938



##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -481,50 +484,12 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
     this->last_register_ = merge_register;
   }
 
-  void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
-    // Lower shape function
-    tec::CCacheKey key(func, host_se_scope_->target);
-    auto cfunc = context_->compiler->LowerShapeFunc(key);
-    int op_index = -1;
-    // pick the only function inside the context
-    ICHECK_EQ(cfunc->funcs->functions.size(), 1);
-    auto pfunc = Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
-    if (context_->seen_funcs.count(pfunc) == 0) {
-      op_index = context_->cached_funcs.size();
-      context_->cached_funcs.push_back(cfunc);
-      context_->seen_funcs[pfunc] = op_index;
-    } else {
-      op_index = context_->seen_funcs[pfunc];
-    }
-
-    // Prepare input and output registers
+  void EmitInvokeTVMOp(const Expr& func, const Expr& inputs, const Expr& outputs,

Review comment:
       Because it was only needed to figure out which Target to use for the function-at-a-time lowering, but that has all been done already by the LowerTEPass.







[GitHub] [tvm] jroesch merged pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
jroesch merged pull request #9483:
URL: https://github.com/apache/tvm/pull/9483


   





[GitHub] [tvm] jroesch commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758733500



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
     return mod;
   }
 
+  void AddExterns(IRModule module) {

Review comment:
      Is the "goal" here that we could eventually delete this and just leave ExternDefinitions in the module?







[GitHub] [tvm] jroesch commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758751452



##########
File path: src/relay/transforms/memory_alloc.cc
##########
@@ -311,48 +292,58 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
           shape_func_ins.push_back(scope->Push(in_shape_var, sh_of));
           input_pos++;
         }
-        is_inputs.push_back(0);
       } else if (state == tec::kNeedInputData) {
         auto new_arg = Mutate(arg);  // already accounts for device
         SEScope arg_se_scope = GetSEScope(arg);
+        // The dynamic shape function is expecting its data on the host/CPU, so insert a
+        // device_copy otherwise. (We'll need to fuse & lower these copies in the same way
+        // we fuse & lower other operators we insert for, eg, dynamic tensor size calculation.)
         if (arg_se_scope != host_se_scope_) {
           new_arg = OnDevice(DeviceCopy(new_arg, arg_se_scope, host_se_scope_), host_se_scope_,
                              /*is_fixed=*/true);
         }
         Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr));
         shape_func_ins.push_back(scope->Push(in_shape_var, new_arg));
         input_pos++;
-        is_inputs.push_back(1);
       } else {
         // TODO(@jroesch): handle kNeedBoth
         LOG(FATAL) << "unsupported shape function input state";
       }
     }
+    ICHECK_EQ(shape_func_ins.size(), func_type_node->arg_types.size());
+
+    // Establish the result shapes.
+    const auto* res_tuple_node = func_type_node->ret_type.as<TupleTypeNode>();
+    ICHECK(res_tuple_node);
 
     Array<Expr> out_shapes;
-    for (size_t i = 0; i < cfunc->outputs.size(); ++i) {
-      auto out = cfunc->outputs[i];
-      auto tt = TensorType(out->shape, out->dtype);
-      // Put shape func on CPU. This also ensures that everything between
-      // shape_of and shape_func are on CPU.
-      auto alloc = OnDevice(MakeStaticAllocation(scope, tt, host_se_scope_, std::to_string(i)),
-                            host_se_scope_, /*is_fixed=*/true);
+    for (size_t i = 0; i < res_tuple_node->fields.size(); ++i) {
+      const auto* tensor_type_node = res_tuple_node->fields[i].as<TensorTypeNode>();
+      ICHECK(tensor_type_node);
+      // Put the shape func on the host. This also ensures that everything between
+      // shape_of and shape_func is similarly on the host.
+      Expr alloc = MakeStaticAllocation(scope, GetRef<TensorType>(tensor_type_node), host_se_scope_,
+                                        std::to_string(i));
+      // TODO(mbs): Why extra var binding?

Review comment:
       I think the original pass was incrementally re-anf-ing the code.
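To illustrate what "re-ANF-ing" means here: in A-normal form every intermediate value is bound to a fresh variable, so a pass that keeps its output in ANF pushes each new expression (such as a static allocation) into the current let-scope and uses the returned variable instead of the expression itself. The following is a minimal standalone sketch under that assumption. `ToyLetList` is a hypothetical class only loosely modeled on the `scope->Push(...)` calls in the hunk; it is not Relay's LetList API.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy let-list: Push binds an expression to a fresh variable and
// returns that variable, which is how A-normal form keeps every intermediate
// value named.
class ToyLetList {
 public:
  std::string Push(const std::string& name_hint, const std::string& expr) {
    std::string var = "%" + name_hint + "_" + std::to_string(bindings_.size());
    bindings_.emplace_back(var, expr);
    return var;
  }

  // Wraps body in the accumulated lets, first binding outermost.
  std::string Get(const std::string& body) const {
    std::string result = body;
    for (auto it = bindings_.rbegin(); it != bindings_.rend(); ++it) {
      result = "let " + it->first + " = " + it->second + "; " + result;
    }
    return result;
  }

 private:
  std::vector<std::pair<std::string, std::string>> bindings_;
};
```

For example, pushing `alloc_tensor()` and then `f(%alloc_0)` and calling `Get` on the second variable yields `let %alloc_0 = alloc_tensor(); let %out_1 = f(%alloc_0); %out_1`, which is where an "extra" var binding per allocation comes from.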







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759561480



##########
File path: src/relay/transforms/inline.cc
##########
@@ -54,22 +56,36 @@ class Inliner : ExprMutator {
       : cur_node_(cur_node), call_graph_(call_graph) {}
 
   Expr VisitExpr_(const CallNode* call_node) final {
-    Expr op = call_node->op;
-    const auto* gvn = op.as<GlobalVarNode>();
+    // We can work with calls in both Relay and call_lowered form.
+    Array<Expr> args;
+    Expr op;

Review comment:
      Introduced a 'GetAnyCall' helper since this pattern was starting to multiply.
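The idea behind such a helper is to give passes one uniform view of a call, whether it is an ordinary Relay call or a call_lowered call that wraps the real callee and its arguments. A minimal standalone sketch, assuming a toy call node that stores the unwrapped callee and arguments in separate fields; `ToyCall`, `AnyCall` and `GetAnyCallSketch` are hypothetical and do not reproduce the actual signature of the helper in the PR.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy call node. When op == "call_lowered", the real callee and
// its arguments live in lowered_callee / lowered_args, loosely mirroring
// call_lowered(func, (args...)).
struct ToyCall {
  std::string op;
  std::vector<std::string> args;          // used by ordinary calls
  std::string lowered_callee;             // used by call_lowered
  std::vector<std::string> lowered_args;  // used by call_lowered
};

// Uniform view over both call forms.
struct AnyCall {
  std::string op;
  std::vector<std::string> args;
};

AnyCall GetAnyCallSketch(const ToyCall& call) {
  if (call.op == "call_lowered") {
    // Unwrap: expose the lowered callee and its real arguments.
    return {call.lowered_callee, call.lowered_args};
  }
  return {call.op, call.args};
}
```

A pass such as the Inliner can then inspect `GetAnyCallSketch(call).op` without caring which form it was handed.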







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759585862



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
     return mod;
   }
 
+  void AddExterns(IRModule module) {
+    // Everything tagged with "Compiler" has been compiled, so remove those definitions.
+    std::vector<GlobalVar> to_be_deleted;
+    for (const auto& kv : module->functions) {
+      if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
+        to_be_deleted.push_back(kv.first);
+      }
+    }
+    for (const auto& global_var : to_be_deleted) {
+      module->Remove(global_var);
+    }
+    // HOWEVER we still need a Relay definition to go with those now external functions, so
+    // retrieve them from the cache and mark them with "ExternalSymbol".
+    for (const auto& kv1 : cache_) {
+      auto src_func = kv1.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
+          if (const auto* function_node = kv2.second.as<FunctionNode>()) {
+            // Abandon the existing function annotations.
+            Function function(function_node->params, function_node->body, function_node->ret_type,
+                              function_node->type_params, /*attrs=*/{}, function_node->span);

Review comment:
      This is a classic 'double lifting confusion'.
   
               Optional<DictAttrs> none = {};
               Optional<DictAttrs> empty = DictAttrs();
   
   The caller will have to be careful about which they mean...
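The same pitfall can be reproduced with `std::optional` standing in for TVM's `Optional`: `{}` means "no value supplied", while an explicitly constructed empty dictionary means "replace with nothing". A minimal sketch, assuming a hypothetical `WithAttrsSketch` helper that, like the WithFields behaviour described later in this thread, only overwrites the attributes when a value is actually supplied.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>

using Attrs = std::map<std::string, std::string>;

// Hypothetical WithFields-style helper: a disengaged optional means
// "keep the existing attributes"; an engaged optional replaces them.
Attrs WithAttrsSketch(const Attrs& existing, std::optional<Attrs> opt_attrs) {
  return opt_attrs.has_value() ? *opt_attrs : existing;
}
```

Calling `WithAttrsSketch(existing, {})` keeps the old attributes because `{}` value-initializes the optional to "no value"; clearing them requires passing `Attrs{}` explicitly.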







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759580953



##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -722,18 +725,21 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     for (const auto& kv : targets_) {
       VLOG(1) << "target: " << kv.second->ToDebugString();
     }
-    if (target_host_.defined()) {
-      VLOG(1) << "target host: " << target_host_->ToDebugString();
-    }
+    ICHECK(target_host_.defined());

Review comment:
       done







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759753045



##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -395,7 +400,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
   //
   // The result will be the return type of the operator.
   Type PrimitiveCall(const FuncTypeNode* op, Array<Type> arg_types, const Attrs& attrs,
-                     const Span& span) {
+                     const Span& span, const Expr& expr) {

Review comment:
       Good catch -- I needed it to debug a failure at some point but forgot to back it out again.
   
   BTW, we very much need spans everywhere to help with debugging.







[GitHub] [tvm] electriclilies commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758844510



##########
File path: include/tvm/runtime/vm/executable.h
##########
@@ -133,21 +133,20 @@ class Executable : public ModuleNode {
 
   /*!
    * \brief Returns a description of all the constants in the executable in human-readable
-   * format. Not intended to be machine readable, but rather to help with debugging and
-   * diffing generated code.
+   * format. Intended for debugging and diff-testing.

Review comment:
      Did this get changed?

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -722,18 +725,21 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     for (const auto& kv : targets_) {
       VLOG(1) << "target: " << kv.second->ToDebugString();
     }
-    if (target_host_.defined()) {
-      VLOG(1) << "target host: " << target_host_->ToDebugString();
-    }
+    ICHECK(target_host_.defined());

Review comment:
       Can you add a message here?

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -314,9 +316,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
   /*!
    * brief Create a function call
    * \param call_lowered_props The lowered function and the arguments to call it with
-   * \param call The call we got func and args from
+   * \param call_node The call we got func and args from
    */
-  void CreateFuncCall(CallLoweredProps call_lowered_props, Call call) {
+  void CreateFuncCall(CallLoweredProps call_lowered_props, const Expr& result_expr) {

Review comment:
      Change the documentation to match the parameter name. Also, since we only use result_expr for the PackSid logic, it would be nice to add a note to the documentation explaining why we need to pass it in.
   
   (also note: I think CI should have caught this in sanity check?)

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -435,42 +497,56 @@ using AnalysisRemapping = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual
  */
 class LowerTensorExprMutator : public DeviceAwareExprMutator {
  public:
-  LowerTensorExprMutator(const IRModule& module, ProcessFn process_fn, const String& module_name,
-                         TECompiler compiler)
+  LowerTensorExprMutator(const IRModule& module, ProcessFn process_fn, String module_name,
+                         TECompiler compiler, SEScope host_se_scope)

Review comment:
       [placemarker]

##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -215,24 +215,35 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
 
     if (memory_plan_.defined()) {
       // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
-      func_info =
-          relay::tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan_->expr_to_storage_info);
+      tec::TargetMap tec_target_map;
+      for (const auto& pair : targets_) {
+        tec_target_map.emplace(static_cast<DLDeviceType>(pair.first->value), pair.second);
+      }

Review comment:
       Is this copy just to make sure UpdateMainWorkspaceSize doesn't have access to targets_?

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
     return mod;
   }
 
+  void AddExterns(IRModule module) {
+    // Everything tagged with "Compiler" has been compiled, so remove those definitions.
+    std::vector<GlobalVar> to_be_deleted;
+    for (const auto& kv : module->functions) {
+      if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
+        to_be_deleted.push_back(kv.first);
+      }
+    }
+    for (const auto& global_var : to_be_deleted) {
+      module->Remove(global_var);
+    }
+    // HOWEVER we still need a Relay definition to go with those now external functions, so
+    // retrieve them from the cache and mark them with "ExternalSymbol".
+    for (const auto& kv1 : cache_) {
+      auto src_func = kv1.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
+          if (const auto* function_node = kv2.second.as<FunctionNode>()) {
+            // Abandon the existing function annotations.
+            Function function(function_node->params, function_node->body, function_node->ret_type,
+                              function_node->type_params, /*attrs=*/{}, function_node->span);
+            // Mark function as 'extern' using the "ExternalSymbol" attribute.
+            function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint);
+            module->Add(kv2.first, function);
+          }
+        }
+      }
+    }
+  }
+
   Array<tvm::runtime::Module> LowerExternalFunctions() {
     Array<tvm::runtime::Module> ret;
-    std::unordered_map<std::string, std::string> cached_symbol;
     std::vector<CCacheKey> cached_ext_funcs;
 
     for (const auto& it : cache_) {
       auto src_func = it.first->source_func;
       ICHECK(src_func.defined());
-      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
-        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
-        std::string code_gen_name = code_gen.value();
+      Optional<String> opt_compiler = src_func->GetAttr<String>(attr::kCompiler);
+      if (opt_compiler.defined()) {
+        Optional<String> opt_symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(opt_symbol_name.defined()) << "No external symbol is set for:" << std::endl
+                                          << PrettyPrint(src_func);
+        VLOG(1) << "using external codegen '" << opt_compiler.value() << "' for name '"
+                << opt_symbol_name.value() << "' and function:" << std::endl
+                << PrettyPrint(src_func);
         cached_ext_funcs.push_back(it.first);
 
-        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
-                                      << AsText(src_func, false);
-
-        std::string sn = symbol_name.value();
-        if (cached_symbol.count(sn)) {
-          cached_symbol[sn] = code_gen_name;
-        } else {
-          ICHECK_NE(sn, code_gen_name)
-              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
-        }
-
-        std::string ext_name = "relay.ext." + code_gen_name;
+        std::string ext_name = "relay.ext." + opt_compiler.value();
         auto pf = tvm::runtime::Registry::Get(ext_name);
         ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
         // No need to keep compiler attribute at this point, functions have been
         // extracted for specific codegen.
         src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        VLOG_CONTEXT << ext_name;
         runtime::Module ext_mod = (*pf)(src_func);
-
-        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
-        ret.push_back(ext_mod);
+        if (ext_mod.defined()) {
+          if (ext_mod->GetFunction(opt_symbol_name.value(), /*query_imports=*/true) == nullptr) {
+            // It's possible the codegen yielded C or C++ tracked separately and thus the
+            // returned runtime module can be empty.

Review comment:
      Is there a way to check that the codegen made C/C++ that is tracked separately? It would be nice to have a check for that here. Otherwise, maybe add some of the comment to the VLOG message?

##########
File path: include/tvm/relay/function.h
##########
@@ -148,6 +148,19 @@ Function WithFields(Function function, Optional<Array<Var>> opt_params = Optiona
                     Optional<DictAttrs> opt_attrs = Optional<DictAttrs>(),
                     Optional<Span> opt_span = Optional<Span>());
 
+/*
+ * \brief Returns the Relay FunctionNode represented by base_func if it should be optimized.
+ * Otherwise returns nullptr.
+ *  - PrimFuncs are obviously not Relay Functions.

Review comment:
       To clarify, if the BaseFunc is a PrimFunc then it returns nullptr?
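One way to read the doc comment quoted above: the helper downcasts a base function and returns a null pointer whenever it is not a Relay function, PrimFuncs included. A minimal standalone sketch using `dynamic_cast` on a toy class hierarchy; `ToyBaseFunc`, `ToyRelayFunc`, `ToyPrimFunc` and `AsOptimizableSketch` are hypothetical, and because the rest of the doc comment's list is cut off in the quoted hunk, any further exclusions the real helper makes are deliberately not modeled.

```cpp
#include <cassert>

// Toy stand-ins for BaseFunc and two of its subclasses.
struct ToyBaseFunc {
  virtual ~ToyBaseFunc() = default;
};
struct ToyRelayFunc : ToyBaseFunc {};
struct ToyPrimFunc : ToyBaseFunc {};

// Returns the Relay function if base_func is one, otherwise nullptr.
// Only the "PrimFuncs are not Relay Functions" case is modeled here.
const ToyRelayFunc* AsOptimizableSketch(const ToyBaseFunc& base_func) {
  return dynamic_cast<const ToyRelayFunc*>(&base_func);
}
```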

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
     return mod;
   }
 
+  void AddExterns(IRModule module) {
+    // Everything tagged with "Compiler" has been compiled, so remove those definitions.
+    std::vector<GlobalVar> to_be_deleted;
+    for (const auto& kv : module->functions) {
+      if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
+        to_be_deleted.push_back(kv.first);
+      }
+    }
+    for (const auto& global_var : to_be_deleted) {
+      module->Remove(global_var);
+    }
+    // HOWEVER we still need a Relay definition to go with those now external functions, so
+    // retrieve them from the cache and mark them with "ExternalSymbol".

Review comment:
       Why do you retrieve from the cache instead of using the to_be_deleted list you just constructed?

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
     return mod;
   }
 
+  void AddExterns(IRModule module) {
+    // Everything tagged with "Compiler" has been compiled, so remove those definitions.
+    std::vector<GlobalVar> to_be_deleted;
+    for (const auto& kv : module->functions) {
+      if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
+        to_be_deleted.push_back(kv.first);
+      }
+    }
+    for (const auto& global_var : to_be_deleted) {
+      module->Remove(global_var);
+    }
+    // HOWEVER we still need a Relay definition to go with those now external functions, so
+    // retrieve them from the cache and mark them with "ExternalSymbol".
+    for (const auto& kv1 : cache_) {
+      auto src_func = kv1.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
+          if (const auto* function_node = kv2.second.as<FunctionNode>()) {
+            // Abandon the existing function annotations.
+            Function function(function_node->params, function_node->body, function_node->ret_type,
+                              function_node->type_params, /*attrs=*/{}, function_node->span);

Review comment:
      Hmm, it just occurred to me that WithFields won't erase the attrs here if you pass {} in as attrs. I should add a note about that to the WithFields documentation.

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
     return mod;
   }
 
+  void AddExterns(IRModule module) {
+    // Everything tagged with "Compiler" has been compiled, so remove those definitions.
+    std::vector<GlobalVar> to_be_deleted;
+    for (const auto& kv : module->functions) {
+      if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
+        to_be_deleted.push_back(kv.first);
+      }
+    }
+    for (const auto& global_var : to_be_deleted) {
+      module->Remove(global_var);
+    }
+    // HOWEVER we still need a Relay definition to go with those now external functions, so
+    // retrieve them from the cache and mark them with "ExternalSymbol".
+    for (const auto& kv1 : cache_) {
+      auto src_func = kv1.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
+          if (const auto* function_node = kv2.second.as<FunctionNode>()) {
+            // Abandon the existing function annotations.
+            Function function(function_node->params, function_node->body, function_node->ret_type,
+                              function_node->type_params, /*attrs=*/{}, function_node->span);
+            // Mark function as 'extern' using the "ExternalSymbol" attribute.
+            function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint);
+            module->Add(kv2.first, function);
+          }
+        }
+      }
+    }
+  }
+
   Array<tvm::runtime::Module> LowerExternalFunctions() {
     Array<tvm::runtime::Module> ret;
-    std::unordered_map<std::string, std::string> cached_symbol;
     std::vector<CCacheKey> cached_ext_funcs;
 
     for (const auto& it : cache_) {
       auto src_func = it.first->source_func;
       ICHECK(src_func.defined());
-      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
-        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
-        std::string code_gen_name = code_gen.value();
+      Optional<String> opt_compiler = src_func->GetAttr<String>(attr::kCompiler);
+      if (opt_compiler.defined()) {
+        Optional<String> opt_symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(opt_symbol_name.defined()) << "No external symbol is set for:" << std::endl
+                                          << PrettyPrint(src_func);
+        VLOG(1) << "using external codegen '" << opt_compiler.value() << "' for name '"
+                << opt_symbol_name.value() << "' and function:" << std::endl
+                << PrettyPrint(src_func);
         cached_ext_funcs.push_back(it.first);
 
-        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
-                                      << AsText(src_func, false);
-
-        std::string sn = symbol_name.value();
-        if (cached_symbol.count(sn)) {
-          cached_symbol[sn] = code_gen_name;
-        } else {
-          ICHECK_NE(sn, code_gen_name)
-              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
-        }
-
-        std::string ext_name = "relay.ext." + code_gen_name;
+        std::string ext_name = "relay.ext." + opt_compiler.value();
         auto pf = tvm::runtime::Registry::Get(ext_name);
         ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
         // No need to keep compiler attribute at this point, functions have been
         // extracted for specific codegen.
         src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        VLOG_CONTEXT << ext_name;
         runtime::Module ext_mod = (*pf)(src_func);
-
-        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
-        ret.push_back(ext_mod);
+        if (ext_mod.defined()) {
+          if (ext_mod->GetFunction(opt_symbol_name.value(), /*query_imports=*/true) == nullptr) {
+            // It's possible the codegen yielded C or C++ tracked separately and thus the
+            // returned runtime module can be empty.

Review comment:
       Or maybe log a warning.
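A minimal standalone sketch of the reviewer's suggestion: look up the external codegen by its "relay.ext." name, fail if it is not registered (as the ICHECK in the hunk does), and log a warning rather than silently continuing when the produced module does not expose the expected symbol. The registry, module type and `LowerExternalSketch` are hypothetical stand-ins, not TVM's `runtime::Registry` or `runtime::Module` API; what the real code does with such a module after the quoted lines is not shown, so here it is simply returned.

```cpp
#include <cassert>
#include <functional>
#include <iostream>
#include <map>
#include <optional>
#include <set>
#include <stdexcept>
#include <string>

// Toy runtime module: just the set of symbols it exports.
struct ToyRuntimeModule {
  std::set<std::string> symbols;
};

// Toy codegen registry keyed by "relay.ext.<compiler>".
using Codegen = std::function<std::optional<ToyRuntimeModule>(const std::string&)>;
using Registry = std::map<std::string, Codegen>;

std::optional<ToyRuntimeModule> LowerExternalSketch(const Registry& registry,
                                                    const std::string& compiler,
                                                    const std::string& symbol) {
  const std::string ext_name = "relay.ext." + compiler;
  auto it = registry.find(ext_name);
  if (it == registry.end()) {
    throw std::runtime_error("Failed to find the codegen tool for " + ext_name);
  }
  std::optional<ToyRuntimeModule> mod = it->second(symbol);
  if (mod.has_value() && mod->symbols.count(symbol) == 0) {
    // Suggested behaviour: make the "tracked separately" case visible.
    std::cerr << "warning: " << ext_name << " produced no definition for '" << symbol
              << "'; assuming its C/C++ output is tracked separately\n";
  }
  return mod;
}
```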

##########
File path: src/relay/analysis/call_graph.cc
##########
@@ -64,9 +67,19 @@ void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) {
   // post-order visitor will visit each AST node of the current function to
   // figure out the dependencies between functions.
   PostOrderVisit(func, [&](const Expr& expr) {
-    if (const GlobalVarNode* gvn = expr.as<GlobalVarNode>()) {
-      auto callee = GetRef<GlobalVar>(gvn);
-      cg_node->AddCalledGlobal(LookupGlobalVar(callee));
+    // TODO(mbs): Cleanup shapes functions.
+    if (const auto* call_node = expr.as<CallNode>()) {
+      CallLoweredProps props = GetCallLoweredProps(call_node);
+      if (props.lowered_func.defined() && props.attrs.metadata.count("prim_shape_fn_var")) {

Review comment:
      Might be nice to add a shape_func field to CallLoweredProps; that way we don't need to extract prim_shape_fn_var from props.attrs directly.
   
   Then, this check would be like
   ```if (props.lowered_func.defined() && props.shape_func.defined())```
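The suggestion amounts to promoting one entry of an untyped metadata map to a typed, optional field, so call sites test the field instead of probing for a string key. A minimal standalone sketch under that reading; `ToyPropsWithMetadata`, `ToyCallLoweredProps`, its `shape_func` field and `FromMetadata` are hypothetical and modeled on the comment above, not on TVM's actual CallLoweredProps definition.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>

// Current style, as described in the comment: the shape function is found by
// probing an untyped metadata map for "prim_shape_fn_var".
struct ToyPropsWithMetadata {
  std::optional<std::string> lowered_func;
  std::map<std::string, std::string> metadata;
};

bool HasShapeFuncViaMetadata(const ToyPropsWithMetadata& props) {
  return props.lowered_func.has_value() && props.metadata.count("prim_shape_fn_var") > 0;
}

// Suggested style: a dedicated optional field.
struct ToyCallLoweredProps {
  std::optional<std::string> lowered_func;
  std::optional<std::string> shape_func;
};

bool HasShapeFunc(const ToyCallLoweredProps& props) {
  return props.lowered_func.has_value() && props.shape_func.has_value();
}

// The metadata lookup then happens once, where the props are extracted.
ToyCallLoweredProps FromMetadata(const ToyPropsWithMetadata& old_props) {
  ToyCallLoweredProps props{old_props.lowered_func, std::nullopt};
  auto it = old_props.metadata.find("prim_shape_fn_var");
  if (it != old_props.metadata.end()) props.shape_func = it->second;
  return props;
}
```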

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -224,23 +271,28 @@ class TECompilerImpl : public TECompilerNode {
     }
     cur_ccache_key_ = key;
 
-    // No need to lower external functions for now. We will invoke the external
-    // codegen tool once and lower all functions together.
-    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
-      auto ir_module = IRModule();
-      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
-      auto func_name = GetUniqueName(name_node.value(), &name_map_);
+    Optional<String> opt_compiler = key->source_func->GetAttr<String>(attr::kCompiler);
+    if (opt_compiler.defined()) {
+      // Don't compile now since we don't have anywhere to put the resulting runtime module.
+      // Instead place the original definition in the cache and wait for LowerExternalFunctions.
+      IRModule ir_module;
+      // Note that the source_func may already be bound to a global function in the module

Review comment:
       In the code you ICHECK that opt_global_symbol is defined, so I'd suggest changing "may" to "should be" or "is"







[GitHub] [tvm] electriclilies commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758868700



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -435,42 +497,56 @@ using AnalysisRemapping = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual
  */
 class LowerTensorExprMutator : public DeviceAwareExprMutator {
  public:
-  LowerTensorExprMutator(const IRModule& module, ProcessFn process_fn, const String& module_name,
-                         TECompiler compiler)
+  LowerTensorExprMutator(const IRModule& module, ProcessFn process_fn, String module_name,
+                         TECompiler compiler, SEScope host_se_scope)

Review comment:
      [placemarker] This is where I stopped reading, please disregard!







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759550224



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
     return mod;
   }
 
+  void AddExterns(IRModule module) {

Review comment:
      Forgot to comment this helper, sorry. Yeah, it's very much a stepping stone pending a more official extern binding representation (probably a new subclass of BaseFuncNode).
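The AddExterns helper quoted earlier in this thread follows a two-phase pattern: first collect the globals tagged with a "Compiler" attribute, then remove them, so the module is never mutated while it is being iterated, and finally re-add each compiled definition as an attribute-free stub marked with "ExternalSymbol". A minimal standalone sketch with `std::map` standing in for IRModule; `ToyFunc`, `ToyIRModule`, the `compiled` map (playing the role of the compilation cache) and `AddExternsSketch` are all hypothetical.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

struct ToyFunc {
  std::string body;
  std::map<std::string, std::string> attrs;
};

using ToyIRModule = std::map<std::string, ToyFunc>;  // global name -> definition

void AddExternsSketch(ToyIRModule& ir_module, const ToyIRModule& compiled) {
  // Phase 1: collect, without mutating the module while iterating it.
  std::vector<std::string> to_be_deleted;
  for (const auto& kv : ir_module) {
    if (kv.second.attrs.count("Compiler")) to_be_deleted.push_back(kv.first);
  }
  // Phase 2: remove the externally compiled definitions.
  for (const auto& name : to_be_deleted) ir_module.erase(name);
  // Phase 3: re-add a stub per compiled function, dropping its old
  // attributes and marking it with its external symbol name.
  for (const auto& kv : compiled) {
    ToyFunc stub{kv.second.body, {}};
    stub.attrs["ExternalSymbol"] = kv.first;
    ir_module[kv.first] = stub;
  }
}
```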







[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759579604



##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -215,24 +215,35 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
 
     if (memory_plan_.defined()) {
       // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
-      func_info =
-          relay::tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan_->expr_to_storage_info);
+      tec::TargetMap tec_target_map;
+      for (const auto& pair : targets_) {
+        tec_target_map.emplace(static_cast<DLDeviceType>(pair.first->value), pair.second);
+      }

Review comment:
       No, it's just a representation mismatch between Map<Integer,...> and unordered_map<DLDeviceType, ...>.
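The copy is a plain key-type conversion: the same entries, re-keyed from an integer to a device-type enum. A minimal standalone sketch, assuming `std::map<int, std::string>` stands in for `Map<Integer, Target>` and a local `ToyDeviceType` enum stands in for `DLDeviceType`; the enumerator names and values are illustrative only.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <unordered_map>

// Stand-in for DLDeviceType; values chosen for illustration only.
enum ToyDeviceType : int { kToyCPU = 1, kToyGPU = 2 };

using IntKeyedTargets = std::map<int, std::string>;                         // like Map<Integer, Target>
using DeviceKeyedTargets = std::unordered_map<ToyDeviceType, std::string>;  // like tec::TargetMap

DeviceKeyedTargets ToDeviceKeyed(const IntKeyedTargets& targets) {
  DeviceKeyedTargets result;
  for (const auto& pair : targets) {
    // Same idea as the static_cast in the hunk: reinterpret the integer key.
    result.emplace(static_cast<ToyDeviceType>(pair.first), pair.second);
  }
  return result;
}
```

Since C++14 the standard library provides `std::hash` for enumeration types, so the enum can key an `unordered_map` directly.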







[GitHub] [tvm] jroesch commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758750942



##########
File path: src/relay/transforms/memory_alloc.cc
##########
@@ -122,64 +107,81 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
     return ret;
   }
 
-  Expr DeviceAwareVisitExpr_(const CallNode* cn) final {
-    Call call = GetRef<Call>(cn);
+  Expr DeviceAwareVisitExpr_(const CallNode* call_node) final {
+    DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
+    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
+
+    if (device_copy_props.body.defined()) {
+      // Special case: device_copy calls remain in their original (and functional) form.
+      // TODO(mbs): device_copy cleanup.
+      return transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
+    }
+
+    if (!call_lowered_props.lowered_func.defined()) {
+      // This is a call to a user-defined Relay function, which will be handled directly by
+      // the VM and does not need conversion to DPS.
+      return transform::DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node);
+    }
+
+    Call call = GetRef<Call>(call_node);
+    VLOG(1) << "converting lowered call to DPS:" << std::endl << PrettyPrint(call);
+
     SEScope se_scope = GetSEScope(call);
-    if (IsPrimitive(cn)) {
-      // Because we are in ANF we do not need to visit the arguments.
-      // TODO(mbs): But does so anyway...
-      LetList& scope = scopes_.back();
-      std::vector<Expr> new_args;
-      for (const auto& it : cn->args) {
-        new_args.push_back(Mutate(it));
-      }
+    LetList& scope = scopes_.back();
 
-      Tuple ins(new_args);
-      Type ret_type = cn->checked_type_;
-      std::vector<TensorType> out_types = FlattenTupleType(ret_type);
+    std::vector<Expr> new_args;
+    for (const auto& arg : call_lowered_props.arguments) {
+      new_args.push_back(Mutate(arg));
+    }
+    Tuple ins(new_args);
+    Type ret_type = call_node->checked_type_;
+    std::vector<TensorType> out_types = FlattenTupleType(ret_type);
+
+    // Handle reshape.

Review comment:
       Never mind, I see now.





[GitHub] [tvm] jroesch commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758750736



##########
File path: src/relay/transforms/memory_alloc.cc
##########
@@ -59,26 +59,11 @@ using namespace tvm::runtime;
 namespace tvm {
 namespace relay {
 
-// Check if the primitive function contains only reshape ops.
-bool IsReshapeOnly(const Expr& expr) {

Review comment:
       Did we break this optimization?
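For context on what the removed helper guarded: per its comment, it checked whether a primitive function contains only reshape ops, presumably so that such calls can be handled as a cheap reshape rather than a full kernel invocation with a fresh output allocation. The sketch below illustrates that check over a hypothetical toy IR (`ToyExpr`, `Var`, `Call` and `IsReshapeOnlyToy` are all invented for illustration). The real helper operates on Relay expressions and may consult a function attribute rather than walking the body.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR (not Relay): an expression is either a variable leaf
// (empty op) or a call of a named operator on argument expressions.
struct ToyExpr {
  std::string op;
  std::vector<std::shared_ptr<ToyExpr>> args;
};
using ToyExprPtr = std::shared_ptr<ToyExpr>;

ToyExprPtr Var() { return std::make_shared<ToyExpr>(); }

ToyExprPtr Call(std::string op, std::vector<ToyExprPtr> args) {
  auto expr = std::make_shared<ToyExpr>();
  expr->op = std::move(op);
  expr->args = std::move(args);
  return expr;
}

// Returns true if every operator call reachable from `expr` is a reshape.
// A bare variable trivially qualifies.
bool IsReshapeOnlyToy(const ToyExprPtr& expr) {
  if (expr->op.empty()) return true;
  if (expr->op != "reshape") return false;
  for (const auto& arg : expr->args) {
    if (!IsReshapeOnlyToy(arg)) return false;
  }
  return true;
}
```

For example, `reshape(reshape(x))` qualifies, while `reshape(add(x, x))` does not, because the `add` needs a real kernel.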





[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759587798



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
     return mod;
   }
 
+  void AddExterns(IRModule module) {
+    // Everything tagged with "Compiler" has been compiled, so remove those definitions.
+    std::vector<GlobalVar> to_be_deleted;
+    for (const auto& kv : module->functions) {
+      if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
+        to_be_deleted.push_back(kv.first);
+      }
+    }
+    for (const auto& global_var : to_be_deleted) {
+      module->Remove(global_var);
+    }
+    // HOWEVER we still need a Relay definition to go with those now external functions, so
+    // retrieve them from the cache and mark them with "ExternalSymbol".
+    for (const auto& kv1 : cache_) {
+      auto src_func = kv1.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
+          if (const auto* function_node = kv2.second.as<FunctionNode>()) {
+            // Abandon the existing function annotations.
+            Function function(function_node->params, function_node->body, function_node->ret_type,
+                              function_node->type_params, /*attrs=*/{}, function_node->span);
+            // Mark function as 'extern' using the "ExternalSymbol" attribute.
+            function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint);
+            module->Add(kv2.first, function);
+          }
+        }
+      }
+    }
+  }
+
   Array<tvm::runtime::Module> LowerExternalFunctions() {
     Array<tvm::runtime::Module> ret;
-    std::unordered_map<std::string, std::string> cached_symbol;
     std::vector<CCacheKey> cached_ext_funcs;
 
     for (const auto& it : cache_) {
       auto src_func = it.first->source_func;
       ICHECK(src_func.defined());
-      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
-        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
-        std::string code_gen_name = code_gen.value();
+      Optional<String> opt_compiler = src_func->GetAttr<String>(attr::kCompiler);
+      if (opt_compiler.defined()) {
+        Optional<String> opt_symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(opt_symbol_name.defined()) << "No external symbol is set for:" << std::endl
+                                          << PrettyPrint(src_func);
+        VLOG(1) << "using external codegen '" << opt_compiler.value() << "' for name '"
+                << opt_symbol_name.value() << "' and function:" << std::endl
+                << PrettyPrint(src_func);
         cached_ext_funcs.push_back(it.first);
 
-        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
-                                      << AsText(src_func, false);
-
-        std::string sn = symbol_name.value();
-        if (cached_symbol.count(sn)) {
-          cached_symbol[sn] = code_gen_name;
-        } else {
-          ICHECK_NE(sn, code_gen_name)
-              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
-        }
-
-        std::string ext_name = "relay.ext." + code_gen_name;
+        std::string ext_name = "relay.ext." + opt_compiler.value();
         auto pf = tvm::runtime::Registry::Get(ext_name);
         ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
         // No need to keep compiler attribute at this point, functions have been
         // extracted for specific codegen.
         src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        VLOG_CONTEXT << ext_name;
         runtime::Module ext_mod = (*pf)(src_func);
-
-        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
-        ret.push_back(ext_mod);
+        if (ext_mod.defined()) {
+          if (ext_mod->GetFunction(opt_symbol_name.value(), /*query_imports=*/true) == nullptr) {
+            // It's possible the codegen yielded C or C++ tracked separately and thus the
+            // returned runtime module can be empty.

Review comment:
       I asked our ARM friends about this since it was also bugging me. It turns out the runtime::Module will be an EthosUModule, which does not support GetFunction for the compiled functions but *does* support GetFunction for the meta function 'get_func_names', which yields a list of those names. However, that's not a standard interface te_compiler can depend on. A HasFunction (or similar) method on runtime::Module would be better, but I'm leaving that for a rain(ier) day.
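A sketch of the HasFunction idea, over a hypothetical module hierarchy. `ToyModule`, `DirectModule`, `NamesOnlyModule` and `HasFunction` itself are illustrative assumptions, not TVM's `runtime::Module` API. The default implementation answers by attempting a lookup; a module that, like the EthosU case described, cannot hand out its compiled functions overrides it with a name-list check.

```cpp
#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a runtime module; not TVM's runtime::Module API.
class ToyModule {
 public:
  using PackedFn = std::function<int(int)>;
  virtual ~ToyModule() = default;

  // Returns an empty PackedFn when `name` cannot be called directly.
  virtual PackedFn GetFunction(const std::string& name) const = 0;

  // Hypothetical HasFunction: by default, a module "has" a function exactly
  // when GetFunction can return it.
  virtual bool HasFunction(const std::string& name) const {
    return static_cast<bool>(GetFunction(name));
  }
};

// A module whose functions are directly callable.
class DirectModule : public ToyModule {
 public:
  PackedFn GetFunction(const std::string& name) const override {
    if (name == "add_one") return [](int x) { return x + 1; };
    return nullptr;
  }
};

// A module that only exposes the names of its compiled functions, so
// GetFunction alone cannot answer "is this symbol provided?".
class NamesOnlyModule : public ToyModule {
 public:
  explicit NamesOnlyModule(std::vector<std::string> names) : names_(std::move(names)) {}

  PackedFn GetFunction(const std::string& /*name*/) const override { return nullptr; }

  bool HasFunction(const std::string& name) const override {
    return std::find(names_.begin(), names_.end(), name) != names_.end();
  }

 private:
  std::vector<std::string> names_;
};
```

With such a hook, a caller could check `mod.HasFunction(symbol)` uniformly, instead of relying on `GetFunction` or on a module-specific meta function.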





[GitHub] [tvm] mbs-octoml commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r759750418



##########
File path: src/relay/backend/vm/compiler.cc
##########
@@ -1051,6 +995,26 @@ transform::Sequential MemoryOpt(const SEScope& cpu_se_scope) {
   return transform::Sequential(std::move(pass_seqs));
 }
 
+transform::Sequential VMCompiler::LowerOperators(const SEScope& host_se_scope) {
+  Array<Pass> pass_seqs;

Review comment:
       Done, thanks.





[GitHub] [tvm] mbs-octoml commented on pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
mbs-octoml commented on pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#issuecomment-983022283


   Thanks @electriclilies, I appreciate your review. Is there any chance of finishing today so I can kick off a CI run overnight? I'm keen to get this off my plate given the number of changes.



[GitHub] [tvm] electriclilies commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758868700



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -435,42 +497,56 @@ using AnalysisRemapping = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual
  */
 class LowerTensorExprMutator : public DeviceAwareExprMutator {
  public:
-  LowerTensorExprMutator(const IRModule& module, ProcessFn process_fn, const String& module_name,
-                         TECompiler compiler)
+  LowerTensorExprMutator(const IRModule& module, ProcessFn process_fn, String module_name,
+                         TECompiler compiler, SEScope host_se_scope)

Review comment:
       [placemarker] This is where I stopped reading; please disregard!



