Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/24 19:23:14 UTC

[GitHub] [tvm] csullivan opened a new pull request #7518: [WIP] Staged refactor and removal of compile engine

csullivan opened a new pull request #7518:
URL: https://github.com/apache/tvm/pull/7518


   This is an in-progress removal of the use of the compile engine. The motivation is to bring TIR compilation into the main flow of the compiler, rather than producing and compiling it via a callback into the compile engine. By replacing calls to Relay primitive functions with calls to TIR PrimFuncs that contain the lowered TIR, we enable:
   - An intermediate stage in the lowering process where Relay and TIR coexist.
   - The ability to add passes at this intermediate stage,
       - for example, memory planning that can infer user-provided information from TE and the resulting TIR.
   
   We are starting with a proof of concept: refactoring GraphRuntimeCodegen to use a newly introduced TE/TIR compiler instead of calling the compile engine directly. In the new flow:
   - The TE/TIR compiler lowers TE in the LowerTensorExpr pass.
   - It replaces each relay.Function(attr:primitive) with a PrimFnCall that contains the lowered TIR.
   - GraphPlanMemory runs on the result.
   - Finally, GraphRuntimeCodegen::VisitExpr lowers the program to graph JSON.
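The four-step flow above can be illustrated with a toy IR in C++. This is a minimal sketch under stated assumptions, not the TECompiler implementation: `Call`, `Module`, the `prim_` naming scheme and the three `*Sketch` functions are hypothetical; only the stage names (LowerTensorExpr, GraphPlanMemory, GraphRuntimeCodegen::VisitExpr) come from the description.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy IR: a "main" function is just an ordered list of calls.
struct Call {
  std::string callee;    // name of the function being called
  bool is_primitive;     // callee is a relay.Function(attr:primitive), not yet lowered
  bool is_prim_fn_call;  // callee is a lowered TIR PrimFunc
};

struct Module {
  std::vector<Call> calls;                     // body of main, in order
  std::map<std::string, std::string> lowered;  // PrimFunc name -> stand-in for its TIR
};

// Stage 1 (cf. LowerTensorExpr): lower each distinct primitive function once and
// rewrite every call to target the lowered PrimFunc. Relay and "TIR" now coexist.
void LowerTensorExprSketch(Module* mod) {
  for (Call& call : mod->calls) {
    if (!call.is_primitive) continue;
    const std::string prim_name = "prim_" + call.callee;  // hypothetical naming scheme
    if (mod->lowered.count(prim_name) == 0) {
      mod->lowered[prim_name] = "tir(" + call.callee + ")";
    }
    call.callee = prim_name;
    call.is_primitive = false;
    call.is_prim_fn_call = true;
  }
}

// Stage 2 (cf. GraphPlanMemory): a pass that runs on the mixed program. This toy
// version gives every call its own storage id; a real planner would reuse buffers.
std::vector<int> PlanMemorySketch(const Module& mod) {
  std::vector<int> storage_ids;
  for (size_t i = 0; i < mod.calls.size(); ++i) {
    storage_ids.push_back(static_cast<int>(i));
  }
  return storage_ids;
}

// Stage 3 (cf. GraphRuntimeCodegen::VisitExpr): emit a graph-JSON-like node list.
std::string EmitGraphSketch(const Module& mod, const std::vector<int>& storage_ids) {
  std::string json = "[";
  for (size_t i = 0; i < mod.calls.size(); ++i) {
    if (i > 0) json += ",";
    json += "{\"func_name\":\"" + mod.calls[i].callee +
            "\",\"storage_id\":" + std::to_string(storage_ids[i]) + "}";
  }
  return json + "]";
}
```

Because stage 1 rewrites the calls before stage 2 runs, the planner in this sketch sees PrimFunc calls rather than primitive Relay functions, which is the intermediate stage the description is aiming for.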
   
   We plan to post an RFC on the Discuss forum soon, but we welcome discussion, comments, and concerns here as we push on this refactor.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r664940865



##########
File path: include/tvm/relay/attrs/annotation.h
##########
@@ -67,6 +67,18 @@ struct CompilerAttrs : public tvm::AttrsNode<CompilerAttrs> {
   }
 };
 
+/*!

Review comment:
       Fix the comment here.

##########
File path: python/tvm/auto_scheduler/relay_integration.py
##########
@@ -376,6 +377,14 @@ def auto_schedule_topi(func_name, outs):
     return schedule
 
 
+@tvm._ffi.register_func("auto_scheduler.relay_integration.te_compiler_update_weights")

Review comment:
       Add comment here.

##########
File path: python/tvm/micro/model_library_format.py
##########
@@ -156,6 +156,7 @@ def _build_function_memory_map(function_metadata):
         2.) A global memory requirement if all functions are executed sequentially
     """
     device_max_workspace = dict()
+    print("TOTAL FUNCTION METADATA: ", function_metadata)

Review comment:
       Remove this.







[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r664942913



##########
File path: src/runtime/graph_executor/graph_executor.cc
##########
@@ -395,6 +395,7 @@ void GraphExecutor::SetupOpExecs() {
 std::pair<std::function<void()>, std::shared_ptr<GraphExecutor::OpArgs> >
 GraphExecutor::CreateTVMOp(const TVMOpParam& param, const std::vector<DLTensor>& args,
                            size_t num_inputs) {
+  std::cout << param.func_name << std::endl;

Review comment:
       Remove this.

##########
File path: src/runtime/graph_executor/graph_executor.cc
##########
@@ -423,6 +424,8 @@ GraphExecutor::CreateTVMOp(const TVMOpParam& param, const std::vector<DLTensor>&
     auto fexec = [arg_ptr]() {
       DLTensor* from = static_cast<DLTensor*>(arg_ptr->arg_values[0].v_handle);
       DLTensor* to = static_cast<DLTensor*>(arg_ptr->arg_values[1].v_handle);
+      std::cout << "from: " << from->device.device_type << "to: " << to->device.device_type

Review comment:
       Remove this.







[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r609050655



##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -181,23 +186,56 @@ class GraphOpNode : public GraphNode {
   const std::string op_type_name_{"tvm_op"};
 };
 
-/*! \brief Code generator for graph runtime */
+/*! \brief Code generator for the graph runtime, produces a module containing the graph JSON,
+ * module, and parameters.
+ */
 class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
  public:
   GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
-    compile_engine_ = CompileEngine::Global();
     targets_ = targets;
   }
 
   LoweredOutput Codegen(relay::Function func) {
-    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
-    storage_device_map_ = (*pf)(func);
+    // Jared: why do we do this? just call C++ API.

Review comment:
       Yeah, I think we need to move up the stack to address this, because everyone is only thinking about their own memory planning needs and not really looking at the bigger picture.







[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r609084369



##########
File path: src/relay/backend/te_compiler_cache.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/backend/tec_compiler_utils.h

Review comment:
       My naming was due to a future plan to put just the cache data structures in this header and later merge them back into te_compiler once we remove the compile engine.
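As a rough illustration of what those cache data structures do, the lowering cache in `TECompilerImpl` (keyed by source function and target, lowering each key at most once, and counting uses for `GetOpWeights`) can be sketched as below. This is a simplified, hypothetical model: `CacheKey`, `CacheValue` and `LoweringCacheSketch` stand in for `CCacheKey`, `CCacheValue` and the real class, and functions are identified by strings instead of Relay IR.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <mutex>
#include <string>
#include <utility>

// Hypothetical stand-in for CCacheKey: (source function, target), both as strings.
using CacheKey = std::pair<std::string, std::string>;

// Hypothetical stand-in for CCacheValue.
struct CacheValue {
  bool defined = false;  // has this key been lowered yet?
  std::string lowered;   // stand-in for CachedFunc
  int use_count = 0;     // number of times this key was requested
};

class LoweringCacheSketch {
 public:
  explicit LoweringCacheSketch(std::function<std::string(const CacheKey&)> lower)
      : lower_(std::move(lower)) {}

  // Lower each key at most once; repeated requests only bump use_count.
  std::string Lower(const CacheKey& key) {
    std::lock_guard<std::mutex> lock(mutex_);
    CacheValue& value = cache_[key];  // default-constructed on first request
    value.use_count += 1;
    if (!value.defined) {
      value.lowered = lower_(key);
      value.defined = true;
    }
    return value.lowered;
  }

  // Analogue of GetOpWeights(): lowered name -> use count.
  std::map<std::string, int> Weights() {
    std::lock_guard<std::mutex> lock(mutex_);
    std::map<std::string, int> weights;
    for (const auto& kv : cache_) weights[kv.second.lowered] = kv.second.use_count;
    return weights;
  }

 private:
  std::function<std::string(const CacheKey&)> lower_;
  std::mutex mutex_;
  std::map<CacheKey, CacheValue> cache_;
};
```

Keeping only these structures in a header, as the comment suggests, would let both the compile engine and the TE compiler share them until one of them is removed.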







[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r609083961



##########
File path: src/target/llvm/llvm_module.cc
##########
@@ -234,7 +238,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
       }
       funcs.push_back(f);
     }
-    ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params));
+    // ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params));

Review comment:
       I think we are allowed to have an empty model now? I can restore the check, but I remember this no longer being true.







[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r664942503



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,760 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+// TODO(@jroesch, @csullivan): declare directly elsewhere
+backend::StaticMemoryPlan GraphPlanMemory(const Function& func);
+
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key, std::function<String(String)> mangle_fn) {
+    return LowerInternal(key, mangle_fn)->cached_func;
+  }
+
+  CachedFunc Lower(const CCacheKey& key, const String mod_name) {
+    auto mangle_fn = [mod_name](String name) {
+      std::cout << "inner mod name" << mod_name << std::endl;
+      return runtime::get_name_mangled(mod_name, name);
+    };
+
+    return Lower(key, mangle_fn);
+  }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    auto mangle_fn = [](String name) { return name; };
+    CCacheValue value = LowerInternal(key, mangle_fn);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
+    }
+
+    // No need to cache external functions as we collected them all to create
+    // external runtime modules.
+    for (const auto& it : cached_ext_funcs) {
+      cache_.erase(it);
+    }
+    return ret;
+  }
+
+  void Clear() final { cache_.clear(); }
+
+  // List all items in the cache.
+  Array<ObjectRef> ListItems() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    Array<ObjectRef> items;
+    for (auto& kv : cache_) {
+      items.push_back(kv.first);
+      items.push_back(kv.second);
+    }
+    return items;
+  }
+
+  /*!
+   * \brief Get the cache key of the function that is being lowered currently
+   * \return the cache key
+   */
+  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+  // implement lowered func
+  CCacheValue LowerInternal(const CCacheKey& key, std::function<String(String)> mangle_fn) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 1;
+      cache_[key] = value;
+    }
+    cur_ccache_key_ = key;
+
+    // No need to lower external functions for now. We will invoke the external
+    // codegen tool once and lower all functions together.
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+      auto ir_module = IRModule();
+      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
+      auto func_name = GetUniqueName(name_node.value(), &name_map_);
+      auto target = Target("ext_dev");
+      auto global_var = GlobalVar(func_name);
+      global_var->checked_type_ = key->source_func->checked_type();
+      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
+      return value;
+    }
+
+    // Enforce use the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+    auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) {
+      auto mangled = mangle_fn(name);
+      std::cout << "Mangled: " << mangled << std::endl;
+      return GetUniqueName(mangled, &name_map_);
+    });
+
+    // Skip lowering for device copy node.
+    const Expr body = (key->source_func)->body;
+    if (const CallNode* call_node = body.as<CallNode>()) {
+      if (call_node->attrs.as<DeviceCopyAttrs>()) {
+        value->cached_func = cfunc;
+        return value;
+      }
+    }
+
+    // NOTE: array will copy on write.
+    Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+    for (te::Tensor arg : cfunc->outputs) {
+      all_args.push_back(arg);
+    }
+
+    std::unordered_map<te::Tensor, tir::Buffer> binds;
+    auto func_name = cfunc->prim_fn_var->name_hint;
+    cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds));
+    value->cached_func = cfunc;
+    return value;
+  }
+
+  // implement lowered shape func
+  CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = shape_func_cache_.find(key);
+    if (it != shape_func_cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      shape_func_cache_[key] = value;
+    }
+    // Enforce use the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+
+    using tvm::transform::PassContext;
+    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+    auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) {
+      return GetUniqueName(name, &name_map_);
+    });
+
+    value->cached_func = cached_func;
+    return value;
+  }
+
+  std::unordered_map<std::string, int> GetOpWeights() {
+    std::unordered_map<std::string, int> weights;
+    for (auto pair : cache_) {
+      auto value = pair.second;
+      auto name = value->cached_func->prim_fn_var->name_hint;
+      weights[name] = value->use_count;
+    }
+    return weights;
+  }
+
+  /*! \brief compiler cache lock*/
+  std::mutex mutex_;
+  /*! \brief internal name map to get an unique name */
+  std::unordered_map<std::string, int> name_map_;
+  /*! \brief internal compiler cache */
+  std::unordered_map<CCacheKey, CCacheValue> cache_;
+  /*! \brief internal compiler cache for shape funcs */
+  std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
+  /*! \brief the cache key of the function that is being lowered currently*/
+  CCacheKey cur_ccache_key_;
+};
+
+TECompiler::TECompiler() {
+  auto object = make_object<TECompilerImpl>();
+  data_ = object;
+}
+
+using AnalysisRemapping = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>;
+
+std::tuple<bool, int, int> IsDeviceCopy(const Function& func) {
+  if (auto call_node = func->body.as<CallNode>()) {
+    if (auto op_node = call_node->op.as<OpNode>()) {
+      if (op_node->name == "device_copy") {
+        auto attrs = call_node->attrs.as<DeviceCopyAttrs>();
+        auto dst = attrs->dst_dev_type;
+        auto src = attrs->src_dev_type;
+        return std::tuple<bool, int, int>(true, src, dst);
+      }
+    }
+  }
+
+  return std::tuple<bool, int, int>(false, -1, -1);
+}
+
+class LowerTensorExpr : public ExprMutator {
+ public:
+  LowerTensorExpr(const IRModule& module, const TargetMap& targets, const DeviceMap& device_ctx_map,
+                  ProcessFn process_fn, AnalysisRemapping* prim_fn_to_call,
+                  const String& module_name, TECompiler compiler)
+      : module_(module),
+        targets_(targets),
+        device_context_map_(device_ctx_map),
+        process_fn(process_fn),
+        prim_fn_to_call(prim_fn_to_call),
+        module_name_(module_name),
+        compiler_(compiler) {}
+
+  Expr VisitExpr_(const CallNode* call) override {
+    Call expr = GetRef<Call>(call);
+    Function func;
+
+    if (call->op.as<FunctionNode>()) {
+      func = GetRef<Function>(call->op.as<FunctionNode>());
+    } else {
+      return ExprMutator::VisitExpr_(call);
+    }
+
+    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
+      // Provide a callback hook which allows one-level up code generators to
+      // act when we process a function.
+      this->process_fn(func);
+      return ExprMutator::VisitExpr_(call);
+    }
+
+    // Process inputs.
+    Array<Expr> args;
+    for (size_t i = 0; i < expr->args.size(); i++) {
+      args.push_back(VisitExpr(expr->args[i]));
+    }
+
+    Target target;
+
+    if (func->GetAttr<String>(attr::kCompiler).defined()) {
+      target = Target("ext_dev");
+      CCacheKey key = CCacheKey(func, target);
+      CachedFunc ext_func = compiler_->Lower(key, module_name_);
+      ICHECK(ext_func.defined()) << "Lowering returned undefined function for "
+                                 << ext_func->prim_fn_var->name_hint;
+
+      Map<GlobalVar, tir::PrimFunc> prim_fns;
+
+      for (auto prim_fn : ext_func->funcs->functions) {
+        CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
+        prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
+      }
+
+      relay::Function func_with_metadata = func;
+      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var);
+      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
+      func_with_metadata = WithAttr(func_with_metadata, "target", ext_func->target);
+
+      // Provide a callback hook which allows one-level up code generators to
+      // act when we process a function.
+      this->process_fn(func_with_metadata);
+
+      auto ret_call = Call(ext_func->prim_fn_var, args, {});
+      (*prim_fn_to_call)[func] = ret_call;
+      return std::move(ret_call);
+    }
+
+    ICHECK_GE(device_context_map_.count(expr), 0)
+        << "Could not find an entry in the device context map for " << PrettyPrint(expr)
+        << "The memory planning was either not performed for this precise node, or there is bug "
+           "in the memory planner.";
+
+    auto& device_context = this->device_context_map_[expr];
+    auto call_dev_type = device_context.device_type;
+
+    // Non-External Relay Function
+    if (targets_.size() == 1) {
+      // The homogeneous execution case, we should only have one target
+      // so we just grab it.
+      const auto& it = targets_.begin();
+      target = (*it).second;
+    } else {
+      // The heterogeneous execution case we have multiple targets
+      // in this case.
+      //
+      // We need to identify the target and translate.
+      std::string call_dev_name;
+      if (call_dev_type == 0) {
+        call_dev_name = "llvm";
+        call_dev_type = kDLCPU;
+      } else {
+        call_dev_name = ::tvm::runtime::DeviceName(call_dev_type);
+      }
+
+      if (targets_.count(call_dev_type) == 0) {
+        std::stringstream msg;
+        msg << "No target is specified for provided device name: `" << call_dev_name << "`\n\n";
+        msg << call_dev_name << " mapped to device type (" << call_dev_type
+            << ") which was not found in the target map.\n";
+        msg << "Availible targets: \n";
+        for (auto target : targets_) {
+          msg << "  " << target.first << "-> " << target.second << "\n";
+        }
+        LOG(FATAL) << msg.str();
+      }
+
+      target = targets_[call_dev_type];
+    }
+
+    CCacheKey key = CCacheKey(func, target);
+    CachedFunc lowered_func = compiler_->Lower(key, module_name_);
+
+    Map<GlobalVar, tir::PrimFunc> prim_fns;
+
+    for (auto prim_fn : lowered_func->funcs->functions) {
+      CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
+      prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
+    }
+
+    // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
+    relay::Function func_with_metadata = func;
+    func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var);
+    func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
+    func_with_metadata = WithAttr(func_with_metadata, "target", lowered_func->target);
+
+    // Provide a callback hook which allows one-level up code generators to
+    // act when we process a function.
+    this->process_fn(func_with_metadata);
+
+    auto tir_call_attrs = make_object<TIRCallAttrs>();
+    if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
+      tir_call_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
+    }
+
+    auto device_copy = IsDeviceCopy(func);
+    if (std::get<0>(device_copy)) {
+      std::cout << "DeviceCopy" << std::endl;
+      auto source_device = std::get<1>(device_copy);
+      auto dst_device = std::get<2>(device_copy);
+      tir_call_attrs->metadata.Set("source_device", tvm::Integer(source_device));
+      tir_call_attrs->metadata.Set("dst_device", tvm::Integer(dst_device));
+    }
+
+    std::cout << "Function: " << func << std::endl;
+
+    tir_call_attrs->metadata.Set("relay_attrs", func->attrs);
+
+    Expr ret_call = Call(lowered_func->prim_fn_var, args, Attrs(tir_call_attrs));
+    (*prim_fn_to_call)[func] = ret_call;
+    return ret_call;
+  }
+
+  IRModule module_;
+  TargetMap targets_;
+  DeviceMap device_context_map_;
+  ProcessFn process_fn;
+  AnalysisRemapping* prim_fn_to_call;

Review comment:
       Remove this.
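For readers following the `VisitExpr_` hunk above, its homogeneous/heterogeneous target selection can be sketched as a standalone function. This is a simplified, hypothetical rendering: device types are plain integers, targets are strings, the error is a C++ exception instead of `LOG(FATAL)`, and `SelectTarget` is not a TVM function. The value 1 for the CPU mirrors DLPack's `kDLCPU`.

```cpp
#include <cassert>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>

// Assumption: mirrors DLPack's kDLCPU value.
constexpr int kCPUDeviceType = 1;

// Hypothetical sketch of the target-selection logic in LowerTensorExpr::VisitExpr_.
std::string SelectTarget(const std::map<int, std::string>& targets, int call_dev_type) {
  // Homogeneous case: only one target, so use it regardless of the device annotation.
  if (targets.size() == 1) return targets.begin()->second;
  // Heterogeneous case: an unannotated call (device type 0) falls back to the CPU.
  if (call_dev_type == 0) call_dev_type = kCPUDeviceType;
  auto it = targets.find(call_dev_type);
  if (it == targets.end()) {
    std::ostringstream msg;
    msg << "No target is specified for device type " << call_dev_type
        << ". Available targets:";
    for (const auto& kv : targets) msg << " " << kv.first << "->" << kv.second;
    throw std::runtime_error(msg.str());
  }
  return it->second;
}
```

Using `find` once, rather than `count` followed by `operator[]`, avoids a second lookup and cannot accidentally insert an empty entry into the map.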







[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r609055722



##########
File path: src/relay/backend/compile_engine.cc
##########
@@ -731,50 +170,50 @@ class CompileEngineImpl : public CompileEngineNode {
     // No need to lower external functions for now. We will invoke the external
     // codegen tool once and lower all functions together.
     if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
-      auto cache_node = make_object<CachedFuncNode>();
+      auto ir_module = IRModule();
       const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
       ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
-      cache_node->func_name = std::string(name_node.value());
-      cache_node->target = Target("ext_dev");
-      cache_node->funcs->Add(GlobalVar(cache_node->func_name), key->source_func);
-      value->cached_func = CachedFunc(cache_node);
+      auto func_name = std::string(name_node.value());
+      auto target = Target("ext_dev");
+      auto global_var = GlobalVar(func_name);
+      global_var->checked_type_ = key->source_func->checked_type();
+      ir_module->Add(global_var, key->source_func);
+      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
       return value;
     }
+
     // Enforce use the target.
     With<Target> target_scope(key->target);
 
     ICHECK(!value->cached_func.defined());
-    auto cfunc = CreateSchedule(key->source_func, key->target);
-    auto cache_node = make_object<CachedFuncNode>(*(cfunc.operator->()));
+    auto cfunc = PrimFuncFor(key->source_func, key->target,
+                             [&](std::string name) { return GetUniqueName(name, name_map_); });
 
     // Skip lowering for device copy node.
     const Expr body = (key->source_func)->body;
     if (const CallNode* call_node = body.as<CallNode>()) {
       if (call_node->attrs.as<DeviceCopyAttrs>()) {
-        value->cached_func = CachedFunc(cache_node);
+        value->cached_func = cfunc;
         return value;
       }
     }
 
-    cache_node->func_name = GetUniqueName(cache_node->func_name);
     // NOTE: array will copy on write.
-    Array<te::Tensor> all_args = cache_node->inputs;
-    for (te::Tensor arg : cache_node->outputs) {
+    Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+    for (te::Tensor arg : cfunc->outputs) {
       all_args.push_back(arg);
     }
-    // lower the function
-    if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
-      cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func);
-    } else {
-      using tvm::transform::PassContext;
-      With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
 
-      std::unordered_map<te::Tensor, tir::Buffer> binds;
-      cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds);
-    }
-    value->cached_func = CachedFunc(cache_node);
+    using tvm::transform::PassContext;
+    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+
+    std::unordered_map<te::Tensor, tir::Buffer> binds;
+    auto func_name = cfunc->prim_fn_var->name_hint;
+    cfunc->funcs->Update(tvm::lower(cfunc->schedule, all_args, func_name, binds));

Review comment:
       We need to fix this; this is BAD, IMO. Should we file a follow-up issue? I think we should remove the Python API and force it into C++.
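The indirection in question is the removed `relay.backend.lower` lookup: if some (typically Python) code registered a function under that name, it silently replaced the C++ lowering path. A minimal sketch of the pattern, assuming a toy string-keyed registry in place of `tvm::runtime::Registry` (`GlobalRegistry`, `LowerInCpp` and `LowerWithHook` are hypothetical names):

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>

// Hypothetical global registry standing in for tvm::runtime::Registry.
using LowerFn = std::function<std::string(const std::string&)>;

std::map<std::string, LowerFn>& GlobalRegistry() {
  static std::map<std::string, LowerFn> registry;
  return registry;
}

// Direct C++ implementation (stand-in for tvm::lower).
std::string LowerInCpp(const std::string& func_name) { return "lowered:" + func_name; }

// The pattern removed by the diff: consult an optional, externally registered hook
// first, and only fall back to the C++ path when nothing is registered.
std::string LowerWithHook(const std::string& func_name) {
  auto it = GlobalRegistry().find("relay.backend.lower");
  if (it != GlobalRegistry().end()) return it->second(func_name);
  return LowerInCpp(func_name);
}
```

With the hook, the result depends on whether some module happened to register the name; calling the C++ function directly keeps the lowering path deterministic and visible in the C++ call graph.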







[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r664942399



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,760 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+// TODO(@jroesch, @csullivan): declare directly elsewhere
+backend::StaticMemoryPlan GraphPlanMemory(const Function& func);
+
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key, std::function<String(String)> mangle_fn) {
+    return LowerInternal(key, mangle_fn)->cached_func;
+  }
+
+  CachedFunc Lower(const CCacheKey& key, const String mod_name) {
+    auto mangle_fn = [mod_name](String name) {
+      std::cout << "inner mod name" << mod_name << std::endl;
+      return runtime::get_name_mangled(mod_name, name);
+    };
+
+    return Lower(key, mangle_fn);
+  }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    auto mangle_fn = [](String name) { return name; };
+    CCacheValue value = LowerInternal(key, mangle_fn);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
+    }
+
+    // No need to cache external functions as we collected them all to create
+    // external runtime modules.
+    for (const auto& it : cached_ext_funcs) {
+      cache_.erase(it);
+    }
+    return ret;
+  }
+
+  void Clear() final { cache_.clear(); }
+
+  // List all items in the cache.
+  Array<ObjectRef> ListItems() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    Array<ObjectRef> items;
+    for (auto& kv : cache_) {
+      items.push_back(kv.first);
+      items.push_back(kv.second);
+    }
+    return items;
+  }
+
+  /*!
+   * \brief Get the cache key of the function that is being lowered currently
+   * \return the cache key
+   */
+  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+  // implement lowered func
+  CCacheValue LowerInternal(const CCacheKey& key, std::function<String(String)> mangle_fn) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 1;
+      cache_[key] = value;
+    }
+    cur_ccache_key_ = key;
+
+    // No need to lower external functions for now. We will invoke the external
+    // codegen tool once and lower all functions together.
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+      auto ir_module = IRModule();
+      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
+      auto func_name = GetUniqueName(name_node.value(), &name_map_);
+      auto target = Target("ext_dev");
+      auto global_var = GlobalVar(func_name);
+      global_var->checked_type_ = key->source_func->checked_type();
+      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
+      return value;
+    }
+
+    // Enforce use the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+    auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) {
+      auto mangled = mangle_fn(name);
+      std::cout << "Mangled: " << mangled << std::endl;
+      return GetUniqueName(mangled, &name_map_);
+    });
+
+    // Skip lowering for device copy node.
+    const Expr body = (key->source_func)->body;
+    if (const CallNode* call_node = body.as<CallNode>()) {
+      if (call_node->attrs.as<DeviceCopyAttrs>()) {
+        value->cached_func = cfunc;
+        return value;
+      }
+    }
+
+    // NOTE: array will copy on write.
+    Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+    for (te::Tensor arg : cfunc->outputs) {
+      all_args.push_back(arg);
+    }
+
+    std::unordered_map<te::Tensor, tir::Buffer> binds;
+    auto func_name = cfunc->prim_fn_var->name_hint;
+    cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds));
+    value->cached_func = cfunc;
+    return value;
+  }
+
+  // implement lowered shape func
+  CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = shape_func_cache_.find(key);
+    if (it != shape_func_cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      shape_func_cache_[key] = value;
+    }
+    // Enforce use the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+
+    using tvm::transform::PassContext;
+    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+    auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) {
+      return GetUniqueName(name, &name_map_);
+    });
+
+    value->cached_func = cached_func;
+    return value;
+  }
+
+  std::unordered_map<std::string, int> GetOpWeights() {
+    std::unordered_map<std::string, int> weights;
+    for (auto pair : cache_) {
+      auto value = pair.second;
+      auto name = value->cached_func->prim_fn_var->name_hint;
+      weights[name] = value->use_count;
+    }
+    return weights;
+  }
+
+  /*! \brief compiler cache lock*/
+  std::mutex mutex_;
+  /*! \brief internal name map to get an unique name */
+  std::unordered_map<std::string, int> name_map_;
+  /*! \brief internal compiler cache */
+  std::unordered_map<CCacheKey, CCacheValue> cache_;
+  /*! \brief internal compiler cache for shape funcs */
+  std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
+  /*! \brief the cache key of the function that is being lowered currently*/
+  CCacheKey cur_ccache_key_;
+};
+
+TECompiler::TECompiler() {
+  auto object = make_object<TECompilerImpl>();
+  data_ = object;
+}
+
+using AnalysisRemapping = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>;
+
+std::tuple<bool, int, int> IsDeviceCopy(const Function& func) {
+  if (auto call_node = func->body.as<CallNode>()) {
+    if (auto op_node = call_node->op.as<OpNode>()) {
+      if (op_node->name == "device_copy") {
+        auto attrs = call_node->attrs.as<DeviceCopyAttrs>();
+        auto dst = attrs->dst_dev_type;
+        auto src = attrs->src_dev_type;
+        return std::tuple<bool, int, int>(true, src, dst);
+      }
+    }
+  }
+
+  return std::tuple<bool, int, int>(false, -1, -1);
+}
+
+class LowerTensorExpr : public ExprMutator {
+ public:
+  LowerTensorExpr(const IRModule& module, const TargetMap& targets, const DeviceMap& device_ctx_map,
+                  ProcessFn process_fn, AnalysisRemapping* prim_fn_to_call,
+                  const String& module_name, TECompiler compiler)
+      : module_(module),
+        targets_(targets),
+        device_context_map_(device_ctx_map),
+        process_fn(process_fn),
+        prim_fn_to_call(prim_fn_to_call),
+        module_name_(module_name),
+        compiler_(compiler) {}
+
+  Expr VisitExpr_(const CallNode* call) override {
+    Call expr = GetRef<Call>(call);
+    Function func;
+
+    if (call->op.as<FunctionNode>()) {
+      func = GetRef<Function>(call->op.as<FunctionNode>());
+    } else {
+      return ExprMutator::VisitExpr_(call);
+    }
+
+    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
+      // Provide a callback hook which allows one-level up code generators to
+      // act when we process a function.
+      this->process_fn(func);
+      return ExprMutator::VisitExpr_(call);
+    }
+
+    // Process inputs.
+    Array<Expr> args;
+    for (size_t i = 0; i < expr->args.size(); i++) {
+      args.push_back(VisitExpr(expr->args[i]));
+    }
+
+    Target target;
+
+    if (func->GetAttr<String>(attr::kCompiler).defined()) {
+      target = Target("ext_dev");
+      CCacheKey key = CCacheKey(func, target);
+      CachedFunc ext_func = compiler_->Lower(key, module_name_);
+      ICHECK(ext_func.defined()) << "Lowering returned undefined function for "
+                                 << ext_func->prim_fn_var->name_hint;
+
+      Map<GlobalVar, tir::PrimFunc> prim_fns;
+
+      for (auto prim_fn : ext_func->funcs->functions) {
+        CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
+        prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
+      }
+
+      relay::Function func_with_metadata = func;
+      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var);
+      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
+      func_with_metadata = WithAttr(func_with_metadata, "target", ext_func->target);
+
+      // Provide a callback hook which allows one-level up code generators to
+      // act when we process a function.
+      this->process_fn(func_with_metadata);
+
+      auto ret_call = Call(ext_func->prim_fn_var, args, {});
+      (*prim_fn_to_call)[func] = ret_call;
+      return std::move(ret_call);
+    }
+
+    ICHECK_GE(device_context_map_.count(expr), 0)
+        << "Could not find an entry in the device context map for " << PrettyPrint(expr)
+        << "The memory planning was either not performed for this precise node, or there is bug "
+           "in the memory planner.";
+
+    auto& device_context = this->device_context_map_[expr];
+    auto call_dev_type = device_context.device_type;
+
+    // Non-External Relay Function
+    if (targets_.size() == 1) {
+      // The homogeneous execution case, we should only have one target
+      // so we just grab it.
+      const auto& it = targets_.begin();
+      target = (*it).second;
+    } else {
+      // The heterogeneous execution case we have multiple targets
+      // in this case.
+      //
+      // We need to identify the target and translate.
+      std::string call_dev_name;
+      if (call_dev_type == 0) {
+        call_dev_name = "llvm";
+        call_dev_type = kDLCPU;
+      } else {
+        call_dev_name = ::tvm::runtime::DeviceName(call_dev_type);
+      }
+
+      if (targets_.count(call_dev_type) == 0) {
+        std::stringstream msg;
+        msg << "No target is specified for provided device name: `" << call_dev_name << "`\n\n";
+        msg << call_dev_name << " mapped to device type (" << call_dev_type
+            << ") which was not found in the target map.\n";
+        msg << "Availible targets: \n";
+        for (auto target : targets_) {
+          msg << "  " << target.first << "-> " << target.second << "\n";
+        }
+        LOG(FATAL) << msg.str();
+      }
+
+      target = targets_[call_dev_type];
+    }
+
+    CCacheKey key = CCacheKey(func, target);
+    CachedFunc lowered_func = compiler_->Lower(key, module_name_);
+
+    Map<GlobalVar, tir::PrimFunc> prim_fns;
+
+    for (auto prim_fn : lowered_func->funcs->functions) {
+      CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
+      prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
+    }
+
+    // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
+    relay::Function func_with_metadata = func;
+    func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var);
+    func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
+    func_with_metadata = WithAttr(func_with_metadata, "target", lowered_func->target);
+
+    // Provide a callback hook which allows one-level up code generators to
+    // act when we process a function.
+    this->process_fn(func_with_metadata);
+
+    auto tir_call_attrs = make_object<TIRCallAttrs>();
+    if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
+      tir_call_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
+    }
+
+    auto device_copy = IsDeviceCopy(func);
+    if (std::get<0>(device_copy)) {
+      std::cout << "DeviceCopy" << std::endl;
+      auto source_device = std::get<1>(device_copy);
+      auto dst_device = std::get<2>(device_copy);
+      tir_call_attrs->metadata.Set("source_device", tvm::Integer(source_device));
+      tir_call_attrs->metadata.Set("dst_device", tvm::Integer(dst_device));
+    }
+
+    std::cout << "Function: " << func << std::endl;

Review comment:
       Remove this
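For reference, `TECompilerImpl::LowerInternal` in the hunk above follows a memoize-under-lock pattern. It looks up the `CCacheKey` while holding `mutex_`, increments `use_count` on every request, and lowers only when no `cached_func` has been produced yet. What follows is a minimal standalone sketch of that pattern, not TVM code. `CacheEntry`, `LowerCache`, and the string-valued lowering callback are hypothetical stand-ins for `CCacheValueNode`, the compiler cache, and the `PrimFuncFor` plus `tvm::LowerSchedule` step.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for CCacheValueNode: a use counter plus the lowered result.
struct CacheEntry {
  int use_count = 0;
  bool lowered = false;      // plays the role of cached_func.defined()
  std::string lowered_name;  // plays the role of the lowered CachedFunc
};

class LowerCache {
 public:
  using LowerFn = std::function<std::string(const std::string&)>;

  explicit LowerCache(LowerFn lower_fn) : lower_fn_(std::move(lower_fn)) {
    assert(lower_fn_ && "a lowering callback is required");
  }

  // Returns the entry for `key`, running lower_fn_ only if no result exists yet.
  std::shared_ptr<CacheEntry> Lower(const std::string& key) {
    std::lock_guard<std::mutex> lock(mutex_);  // held for the whole call
    std::shared_ptr<CacheEntry> value;
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      it->second->use_count += 1;
      if (it->second->lowered) return it->second;  // cache hit
      value = it->second;  // entry exists, but an earlier lowering attempt threw
    } else {
      value = std::make_shared<CacheEntry>();
      value->use_count = 1;
      cache_[key] = value;
    }
    value->lowered_name = lower_fn_(key);
    value->lowered = true;
    ++num_lowerings_;
    return value;
  }

  int num_lowerings() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return num_lowerings_;
  }

 private:
  LowerFn lower_fn_;
  mutable std::mutex mutex_;
  std::unordered_map<std::string, std::shared_ptr<CacheEntry>> cache_;
  int num_lowerings_ = 0;
};
```

Because the lock is held while the callback runs, concurrent requests are serialized. The `lock_guard` in `LowerInternal` above is likewise scoped to the whole function.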

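The target selection in `LowerTensorExpr::VisitExpr_` above reduces to a small lookup. With a single target, that target is used. Otherwise the call's device type is mapped to a target, with device type 0 treated as CPU, and lowering fails with the list of available targets when no match exists. The sketch below restates that logic under simplifying assumptions: a `std::map<int, std::string>` stands in for `TargetMap`, the integer `1` is DLPack's `kDLCPU`, an exception replaces `LOG(FATAL)`, and `ResolveTarget` is a hypothetical name.

```cpp
#include <cassert>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>

// Assumption: device types follow DLPack numbering, where kDLCPU == 1.
constexpr int kCPUDeviceType = 1;

// Hypothetical helper: picks the target string for a call on `call_dev_type`.
std::string ResolveTarget(const std::map<int, std::string>& targets, int call_dev_type) {
  assert(!targets.empty() && "at least one target is required");
  // Homogeneous execution: a single target is used for every call.
  if (targets.size() == 1) return targets.begin()->second;
  // Heterogeneous execution: an unannotated call (device type 0) runs on the CPU.
  if (call_dev_type == 0) call_dev_type = kCPUDeviceType;
  auto it = targets.find(call_dev_type);
  if (it == targets.end()) {
    std::ostringstream msg;
    msg << "No target is specified for device type " << call_dev_type
        << "\nAvailable targets:\n";
    for (const auto& kv : targets) msg << "  " << kv.first << " -> " << kv.second << "\n";
    throw std::runtime_error(msg.str());
  }
  return it->second;
}
```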



To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org




[GitHub] [tvm] tkonolige commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r602419800



##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -181,23 +186,56 @@ class GraphOpNode : public GraphNode {
   const std::string op_type_name_{"tvm_op"};
 };
 
-/*! \brief Code generator for graph runtime */
+/*! \brief Code generator for the graph runtime, produces a module containing the graph JSON,
+ * module, and parameters.
+ */
 class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
  public:
   GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
-    compile_engine_ = CompileEngine::Global();
     targets_ = targets;
   }
 
   LoweredOutput Codegen(relay::Function func) {
-    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
-    storage_device_map_ = (*pf)(func);
+    // Jared: why do we do this? just call C++ API.
+    // auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
+    // storage_device_map_ = (*pf)(func);
+    storage_device_map_ = GraphPlanMemory(func);
+
+    // This first phase moves from implicit use of compile engine,
+    // to instead the lower the incoming IRModule, and then performing

Review comment:
       ```suggestion
       // to instead lowering the incoming IRModule, and then performing
   ```

##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -181,23 +186,56 @@ class GraphOpNode : public GraphNode {
   LoweredOutput Codegen(relay::Function func) {
-    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
-    storage_device_map_ = (*pf)(func);
+    // Jared: why do we do this? just call C++ API.
+    // auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
+    // storage_device_map_ = (*pf)(func);
+    storage_device_map_ = GraphPlanMemory(func);
+
+    // This first phase moves from implicit use of compile engine,
+    // to instead the lower the incoming IRModule, and then performing
+    // the pre-exiting graph runtime code generation phase.

Review comment:
       ```suggestion
       // the pre-existing graph runtime code generation phase.
   ```

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    CCacheValue value = LowerInternal(key);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        ICHECK(code_gen.defined()) << "No external codegen is set";
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";

Review comment:
       ```suggestion
           ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
   ```
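The suggestion only trims the trailing newline from the error message. The check it guards is a name-based lookup: the external codegen for a function whose `Compiler` attribute is `<name>` is expected to be registered as the global function `relay.ext.<name>`. A minimal standalone sketch of that dispatch follows, assuming a hypothetical `CodegenRegistry` in place of `tvm::runtime::Registry` and plain strings in place of Relay functions and runtime modules.

```cpp
#include <cassert>
#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for a codegen callback: function name in, "module" out.
using CodegenFn = std::function<std::string(const std::string& func_name)>;

// Hypothetical stand-in for tvm::runtime::Registry: callbacks keyed by global name.
class CodegenRegistry {
 public:
  void Register(const std::string& name, CodegenFn fn) { fns_[name] = std::move(fn); }

  // Returns nullptr when nothing is registered under `name`.
  const CodegenFn* Get(const std::string& name) const {
    auto it = fns_.find(name);
    return it == fns_.end() ? nullptr : &it->second;
  }

 private:
  std::unordered_map<std::string, CodegenFn> fns_;
};

// Looks up "relay.ext.<compiler>" and invokes it on `func_name`.
std::string RunExternalCodegen(const CodegenRegistry& registry, const std::string& compiler,
                               const std::string& func_name) {
  const std::string ext_name = "relay.ext." + compiler;
  const CodegenFn* pf = registry.Get(ext_name);
  if (pf == nullptr) {
    // No trailing newline, in line with the suggestion above.
    throw std::runtime_error("Failed to find the codegen tool for " + ext_name);
  }
  return (*pf)(func_name);
}
```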

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        ICHECK(code_gen.defined()) << "No external codegen is set";
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);

Review comment:
       ```suggestion
           ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
                                         << AsText(src_func, false) << "\n"
                                         << "Functions with external codegen must have the "
                                         << tvm::attr::kGlobalSymbol << " attr set.";
   ```
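The suggested message tells users which attribute is missing instead of only printing the function. As a hedged illustration of the two invariants the surrounding checks deal with (every externally compiled function carries a global symbol, and no symbol is used twice), here is a standalone sketch. `ExternFunc` and `CollectExternSymbols` are hypothetical, and the duplicate test uses a single map insertion rather than the `cached_symbol` bookkeeping in the diff.

```cpp
#include <cassert>
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a partitioned Relay function and its attributes.
struct ExternFunc {
  std::string compiler;                      // "Compiler" attribute value
  std::optional<std::string> global_symbol;  // "global_symbol" attribute, if set
};

// Returns symbol -> compiler, throwing if a symbol is missing or duplicated.
std::unordered_map<std::string, std::string> CollectExternSymbols(
    const std::vector<ExternFunc>& funcs) {
  std::unordered_map<std::string, std::string> symbols;
  for (const ExternFunc& f : funcs) {
    if (!f.global_symbol.has_value()) {
      throw std::runtime_error(
          "Functions with external codegen must have the global_symbol attr set.");
    }
    // emplace() leaves the map unchanged and reports false if the key already exists.
    bool inserted = symbols.emplace(*f.global_symbol, f.compiler).second;
    if (!inserted) {
      throw std::runtime_error("Found duplicated symbol: " + *f.global_symbol +
                               " for: " + f.compiler);
    }
  }
  return symbols;
}
```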

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        ICHECK(code_gen.defined()) << "No external codegen is set";
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
+    }
+
+    // No need to cache external functions as we collected them all to create
+    // external runtime modules.
+    for (const auto& it : cached_ext_funcs) {
+      cache_.erase(it);
+    }
+    return ret;
+  }
+
+  void Clear() final { cache_.clear(); }
+
+  // List all items in the cache.
+  Array<ObjectRef> ListItems() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    Array<ObjectRef> items;
+    for (auto& kv : cache_) {
+      items.push_back(kv.first);
+      items.push_back(kv.second);
+    }
+    return items;
+  }
+
+  /*!
+   * \brief Get the cache key of the function that is being lowered currently
+   * \return the cache key
+   */
+  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+  // implement lowered func
+  CCacheValue LowerInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      if (!backend::IsCompileEngineCacheDisabled()) {
+        cache_[key] = value;
+      }
+    }
+    cur_ccache_key_ = key;
+
+    // No need to lower external functions for now. We will invoke the external
+    // codegen tool once and lower all functions together.
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+      auto ir_module = IRModule();
+      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
+      auto func_name = std::string(name_node.value());
+      auto target = Target("ext_dev");
+      auto global_var = GlobalVar(func_name);
+      global_var->checked_type_ = key->source_func->checked_type();
+      ir_module->Add(global_var, key->source_func);
+      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
+      return value;
+    }
+    // Enforce use of the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+    auto cfunc = PrimFuncFor(key->source_func, key->target,
+                             [&](std::string name) { return GetUniqueName(name, name_map_); });
+
+    // Skip lowering for device copy node.
+    const Expr body = (key->source_func)->body;
+    if (const CallNode* call_node = body.as<CallNode>()) {
+      if (call_node->attrs.as<DeviceCopyAttrs>()) {
+        value->cached_func = cfunc;
+        return value;
+      }
+    }
+
+    std::cout << "Input Size: " << cfunc->inputs.size() << std::endl;
+    std::cout << "Output Size: " << cfunc->outputs.size() << std::endl;
+    // NOTE: array will copy on write.
+    Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+    for (te::Tensor arg : cfunc->outputs) {
+      all_args.push_back(arg);
+    }
+
+    std::cout << "Allargs Size: " << all_args.size() << std::endl;
+
+    using tvm::transform::PassContext;
+    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+
+    std::unordered_map<te::Tensor, tir::Buffer> binds;
+    auto func_name = cfunc->prim_fn_var->name_hint;
+    cfunc->funcs->Update(tvm::lower(cfunc->schedule, all_args, func_name, binds));
+    value->cached_func = cfunc;
+    return value;
+  }
+
+  // implement lowered shape func
+  CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = shape_func_cache_.find(key);
+    if (it != shape_func_cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      shape_func_cache_[key] = value;
+    }
+    // Enforce use of the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+    auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) {
+      return GetUniqueName(name, name_map_);
+    });
+
+    value->cached_func = cached_func;
+    return value;
+  }
+
+  /*! \brief compiler cache lock*/
+  std::mutex mutex_;
+  /*! \brief internal name map to get a unique name */
+  std::unordered_map<std::string, int> name_map_;
+  /*! \brief internal compiler cache */
+  std::unordered_map<CCacheKey, CCacheValue> cache_;
+  /*! \brief internal compiler cache for shape funcs */
+  std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
+  /*! \brief the cache key of the function that is being lowered currently*/
+  CCacheKey cur_ccache_key_;
+};
+
+TECompiler::TECompiler() {
+  auto object = make_object<TECompilerImpl>();
+  data_ = object;
+}
+
+class LowerTensorExpr : public ExprMutator {
+ public:
+  LowerTensorExpr(const IRModule& module, const TargetsMap& targets,
+                  const DeviceContextMap& device_ctx_map, TECompiler compiler)
+      : module_(module),
+        targets_(targets),
+        device_context_map_(device_ctx_map),
+        compiler_(compiler) {}
+
+  Expr VisitExpr_(const CallNode* call) override {
+    Call expr = GetRef<Call>(call);
+    Function func;
+
+    if (call->op.as<FunctionNode>()) {
+      func = GetRef<Function>(call->op.as<FunctionNode>());
+    } else {
+      return ExprMutator::VisitExpr_(call);
+    }
+
+    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
+      // LOG(FATAL) << "TVM only support calls to primitive functions "
+      //           << "(i.e functions composed of fusable operator invocations)";
+      return ExprMutator::VisitExpr_(call);
+    }
+
+    // Process inputs.
+    Array<Expr> args;
+    for (size_t i = 0; i < expr->args.size(); i++) {
+      args.push_back(VisitExpr(expr->args[i]));
+    }
+
+    Target target;
+
+    if (func->GetAttr<String>(attr::kCompiler).defined()) {
+      target = Target("ext_dev");
+      CCacheKey key = CCacheKey(func, target);
+      CachedFunc ext_func = compiler_->Lower(key);
+      ICHECK(ext_func.defined()) << "External function is not defined.";

Review comment:
       Can you add the function name?
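The lookup-or-lower pattern in `LowerInternal` quoted above can be sketched on its own. This is a minimal sketch, not TVM code. `CacheValue` and `LoweringCache` are hypothetical stand-ins for `CCacheValueNode` and the cache inside `TECompilerImpl`. String keys replace `CCacheKey`, and the `lower` callback stands in for `PrimFuncFor` plus `tvm::lower`. The sketch keeps the quoted behavior: the mutex is held for the whole call, a hit increments `use_count`, and when caching is disabled the value is built but never stored.

```cpp
#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for CCacheValueNode: the lowered result plus a use count.
struct CacheValue {
  int use_count = 0;
  std::shared_ptr<std::string> lowered;  // null until lowering has run
};

// Hypothetical stand-in for the cache held by TECompilerImpl.
class LoweringCache {
 public:
  explicit LoweringCache(bool cache_disabled = false) : cache_disabled_(cache_disabled) {}

  // Returns the value for `key`, running `lower` only if no lowered result exists yet.
  std::shared_ptr<CacheValue> Lower(const std::string& key,
                                    const std::function<std::string(const std::string&)>& lower) {
    std::lock_guard<std::mutex> lock(mutex_);  // held across lowering, as in the quoted code
    std::shared_ptr<CacheValue> value;
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      it->second->use_count += 1;
      if (it->second->lowered) return it->second;  // cache hit
      value = it->second;
    } else {
      value = std::make_shared<CacheValue>();
      if (!cache_disabled_) cache_[key] = value;  // not stored when caching is disabled
    }
    value->lowered = std::make_shared<std::string>(lower(key));
    return value;
  }

  // Unlocked read; only meant for single-threaded inspection.
  std::size_t size() const { return cache_.size(); }

 private:
  bool cache_disabled_;
  std::mutex mutex_;
  std::unordered_map<std::string, std::shared_ptr<CacheValue>> cache_;
};
```

Holding the lock while lowering matches the quoted code. The trade-off is that concurrent callers lowering unrelated functions wait on each other.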

##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
+  auto n = make_object<LoweredOutputNode>();
+  n->outputs = std::move(outputs);
+  n->implementation = std::move(impl);
+  data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+  auto n = make_object<CCacheKeyNode>();
+  n->source_func = std::move(source_func);
+  n->target = std::move(target);
+  data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array<te::Tensor> inputs,
+                       tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+                       tvm::Array<Integer> shape_func_param_states, IRModule funcs) {
+  auto n = make_object<CachedFuncNode>();
+  n->target = target;
+  n->prim_fn_var = prim_fn_var;
+  n->inputs = inputs;
+  n->outputs = outputs;
+  n->schedule = schedule;
+  n->shape_func_param_states = shape_func_param_states;
+  n->funcs = funcs;
+  data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+  // for now, we always use int32 shape when possible
+  // even if the result of shape inference becomes int64.
+  Array<IndexExpr> res;
+  for (IndexExpr val : shape) {
+    const int64_t* pval = tir::as_const_int(val);
+    if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+      ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+      ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+      res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+      res.push_back(val);
+#endif  // TVM_INDEX_DEFAULT_I64
+    } else if (val->IsInstance<tir::AnyNode>()) {
+      res.push_back(val.as<tir::AnyNode>()->ToVar());
+    } else {
+      res.push_back(val);
+    }
+  }
+  return res;
+}
+
+// The getter to get schedule from compile engine.
+// Get schedule from functor.
+class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+  explicit ScheduleGetter(Target target)
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {
+    // Whether to use auto_scheduler schedule.
+    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+  }
+
+  CachedFunc Create(const Function& prim_func, std::function<std::string(std::string)> renamer) {
+    Array<tvm::te::Tensor> fn_inputs;
+    for (Var param : prim_func->params) {
+      Array<tvm::te::Tensor> inputs;
+      if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+        tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+        fn_inputs.push_back(tensor);
+        inputs.push_back(tensor);
+      } else {
+        // flatten tuple of tensor type.
+        const auto* tuple_type = param->type_as<TupleTypeNode>();
+        for (Type field : tuple_type->fields) {
+          const auto* ttype = field.as<TensorTypeNode>();
+          // TODO(@icemelon): Allow recursive tuple
+          ICHECK(ttype != nullptr);
+          tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+          fn_inputs.push_back(tensor);
+          inputs.push_back(tensor);
+        }
+      }
+      memo_[param] = inputs;
+    }
+    readable_name_stream_ << "fused";
+    auto outputs = this->VisitExpr(prim_func->body);
+    auto candidate_name = readable_name_stream_.str();
+    constexpr static size_t kMaxFuncNameLength = 80;
+    if (candidate_name.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+      candidate_name = truncated_name.str();
+    }
+
+    ICHECK(anchor_op_.defined());
+    // Fusion over tupled results may leave identity relationships
+    // between inputs and outputs, and those should not be scheduled.
+    // Hence schedule only non PlaceholderOp outputs.
+    tvm::Array<te::Tensor> tensor_outs;
+    for (const auto& tensor : outputs) {
+      if (!tensor->op.as<te::PlaceholderOpNode>()) {
+        tensor_outs.push_back(tensor);
+      }
+    }
+
+    te::Schedule schedule;
+    // No need to register schedule for device copy op.
+    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
+      if (use_auto_scheduler_) {
+        const auto* fauto_schedule =
+            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+        ICHECK(fauto_schedule != nullptr)
+            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
+        ObjectRef obj = (*fauto_schedule)(tensor_outs);
+        if (obj.defined()) {
+          schedule = Downcast<te::Schedule>(obj);
+        }
+      }
+
+      // Use TOPI schedule if user specified, or the function has no auto_scheduler schedule.
+      if (!schedule.defined()) {
+        ICHECK(anchor_implementation_.defined());
+        schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
+      }
+      for (const auto& scalar : scalars_) {
+        if (schedule->Contain(scalar)) {
+          schedule[scalar].compute_inline();
+        }
+      }
+    }
+
+    auto prim_fn_var = GlobalVar(candidate_name);
+    prim_fn_var->checked_type_ = prim_func->checked_type();
+    return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {});
+  }
+
+  Array<te::Tensor> VisitExpr_(const VarNode* op) final {
+    LOG(FATAL) << "Free variable " << op->name_hint();

Review comment:
       Could you make this a little more detailed? Maybe even "Unexpected free variable " << op->name_hint() << ", expected " ...
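The name truncation in `ScheduleGetter::Create` quoted above can be illustrated in isolation. In this sketch, `TruncateFuncName` is a hypothetical helper that mirrors the quoted logic with the same 80-character limit. Names over the limit are cut and suffixed with `_<hash>_`. Because `std::hash` output is implementation-defined, the suffix makes collisions between long names that share a prefix unlikely, but not impossible.

```cpp
#include <cstddef>
#include <functional>
#include <sstream>
#include <string>

// Hypothetical helper mirroring the quoted truncation: keep the first `max_len`
// characters and append "_<hash of the full name>_".
std::string TruncateFuncName(const std::string& candidate_name, std::size_t max_len = 80) {
  if (candidate_name.size() <= max_len) return candidate_name;
  std::stringstream truncated_name;
  truncated_name << candidate_name.substr(0, max_len);
  truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
  return truncated_name.str();
}
```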

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    CCacheValue value = LowerInternal(key);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        ICHECK(code_gen.defined()) << "No external codegen is set";

Review comment:
       Is this check necessary? Didn't line 100 already check if `code_gen` (`src_func->GetAttr<String>(attr::kCompiler)`) was defined?
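One way to avoid the repeated check is to read the attribute once and branch on the result. The sketch below uses standard-library stand-ins. `SourceFunc` and `GetAttr` are hypothetical. The keys `"Compiler"` and `"global_symbol"` are assumed to correspond to `attr::kCompiler` and `tvm::attr::kGlobalSymbol`.

The sketch also records each global symbol as it goes and rejects any symbol seen twice. In the quoted code, `cached_symbol` is only written inside the `count(sn)` branch, so it appears never to be populated. The sketch therefore assumes that rejecting duplicates is the intended behavior.

```cpp
#include <map>
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for a Relay function's attribute dictionary.
struct SourceFunc {
  std::map<std::string, std::string> attrs;
};

// Hypothetical attribute lookup: returns nullopt when the key is absent.
std::optional<std::string> GetAttr(const SourceFunc& f, const std::string& key) {
  auto it = f.attrs.find(key);
  if (it == f.attrs.end()) return std::nullopt;
  return it->second;
}

// Returns (symbol, codegen) pairs for external functions, reading the
// "Compiler" attribute once and rejecting a global symbol that appears twice.
std::vector<std::pair<std::string, std::string>> CollectExternal(
    const std::vector<SourceFunc>& funcs) {
  std::unordered_map<std::string, std::string> symbol_to_codegen;
  std::vector<std::pair<std::string, std::string>> out;
  for (const auto& f : funcs) {
    auto code_gen = GetAttr(f, "Compiler");
    if (!code_gen) continue;  // not external; no second "is it defined" check needed
    auto symbol = GetAttr(f, "global_symbol");
    if (!symbol) throw std::runtime_error("No external symbol is set");
    // emplace() reports whether the key was new, so detection and recording
    // happen in one step.
    if (!symbol_to_codegen.emplace(*symbol, *code_gen).second) {
      throw std::runtime_error("Found duplicated symbol: " + *symbol + " for: " + *code_gen);
    }
    out.emplace_back(*symbol, *code_gen);
  }
  return out;
}
```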

##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -349,65 +380,47 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<G
     return AddNode(node, GetRef<Expr>(op));
   }
 
-  std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
-    Expr expr = GetRef<Expr>(op);
-    Function func;
-    if (op->op.as<OpNode>()) {
-      LOG(FATAL) << "Operators should be transformed away; try applying"
-                 << "the fuse_ops transformation to the expression.";
-    } else if (op->op.as<GlobalVarNode>()) {
-      LOG(FATAL) << "Not implemented";
-    } else if (op->op.as<FunctionNode>()) {
-      func = GetRef<Function>(op->op.as<FunctionNode>());
-    } else {
-      LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
-    }
-    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
-      LOG(FATAL) << "TVM only support calls to primitive functions "
-                 << "(i.e functions composed of fusable operator invocations)";
-    }
+  std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
+    relay::Call call = GetRef<Call>(call_node);
 
-    auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
-    auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
-    Target target;
-    // Handle external function
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      target = Target("ext_dev");
-      CCacheKey key = (*pf0)(func, target);
-      CachedFunc ext_func = (*pf1)(compile_engine_, key);
-      ICHECK(ext_func.defined()) << "External function is not defined.";
-      UpdateConstants(func, &params_);
-      return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
-    }
+    if (auto global_node = call->op.as<GlobalVarNode>()) {
+      auto prim_fn_name = global_node->name_hint;
 
-    ICHECK_GE(storage_device_map_.count(expr), 0);
-    auto& device_type = storage_device_map_[expr][1];
-    auto call_dev_type = device_type[0]->value;
-    // Normal Relay Function
-    if (targets_.size() == 1) {
-      // homogeneous execution.
-      const auto& it = targets_.begin();
-      target = (*it).second;
-    } else {
-      // heterogeneous execution.
-      std::string call_dev_name;
-      if (call_dev_type == 0) {
-        call_dev_name = "llvm";
+      Target target;
+
+      // // Handle external function
+      // if (func->GetAttr<String>(attr::kCompiler).defined()) {
+      //   UpdateConstants(func, &params_);
+      //   return GraphAddCallNode(call_node, prim_fn_name, prim_fn_name);
+      // }
+
+      ICHECK_GE(storage_device_map_.count(call), 0);
+      auto& device_type = storage_device_map_[call][1];
+      auto call_dev_type = device_type[0]->value;
+      // Normal Relay Function
+      if (targets_.size() == 1) {
+        // homogeneous execution.
+        const auto& it = targets_.begin();
+        target = (*it).second;
       } else {
-        call_dev_name = runtime::DeviceName(call_dev_type);
-      }
-      if (targets_.count(call_dev_type) == 0) {
-        LOG(FATAL) << "No target is provided for device " << call_dev_name;
+        // heterogeneous execution.
+        std::string call_dev_name;
+        if (call_dev_type == 0) {
+          call_dev_name = "llvm";
+        } else {
+          call_dev_name = runtime::DeviceName(call_dev_type);
+        }
+        if (targets_.count(call_dev_type) == 0) {
+          LOG(FATAL) << "No target is provided for device " << call_dev_name;
+        }
+        target = targets_[call_dev_type];
       }
-      target = targets_[call_dev_type];
-    }
-    CCacheKey key = (*pf0)(func, target);
-    CachedFunc lowered_func = (*pf1)(compile_engine_, key);
-    if (!lowered_funcs_.count(target->str())) {
-      lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
+
+      return GraphAddCallNode(call_node, _GetUniqueName(prim_fn_name), prim_fn_name);
+    } else {
+      LOG(FATAL) << "BadCase: " << PrettyPrint(call) << std::endl;

Review comment:
       ```suggestion
         LOG(FATAL) << "Graph runtime codegen can only handle calls to global functions, but it got a " << call->op->GetTypeKey() << " (should be GlobalVar). This is what was provided: " << PrettyPrint(call);
   ```
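For context on the hunk above, the target selection that remains in `VisitExpr_` can be sketched on its own. This sketch makes two assumptions. First, targets are plain strings keyed by integer device type rather than a `TargetsMap`. Second, `DeviceName` is a hypothetical stand-in for `runtime::DeviceName` that knows only two codes, taken to be 1 for CPU and 2 for GPU as in DLPack's numbering at the time.

With a single target, every call uses it. Otherwise, the device type assigned by memory planning selects the target, and a missing entry is an error.

```cpp
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for runtime::DeviceName (assumed DLPack codes: 1 = CPU, 2 = GPU).
std::string DeviceName(int dev_type) {
  switch (dev_type) {
    case 1: return "cpu";
    case 2: return "gpu";
    default: return "device_" + std::to_string(dev_type);
  }
}

// Homogeneous case: one target serves every call.
// Heterogeneous case: the call's device type must have a target.
std::string SelectTarget(const std::map<int, std::string>& targets, int call_dev_type) {
  if (targets.size() == 1) return targets.begin()->second;
  auto it = targets.find(call_dev_type);
  if (it == targets.end()) {
    std::string call_dev_name = call_dev_type == 0 ? "llvm" : DeviceName(call_dev_type);
    throw std::runtime_error("No target is provided for device " + call_dev_name);
  }
  return it->second;
}
```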

##########
File path: src/driver/driver_api.cc
##########
@@ -244,14 +244,17 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
   }
 
   if (target->kind->device_type == kDLCPU && target_host == target) {
-    ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
-                                       << "and host_target are both llvm target."
-                                       << "\n";
+    // We need to relax this check for just TIR functions.

Review comment:
       Maybe add a TODO to fix this up?

##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -349,65 +380,47 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<G
+      ICHECK_GE(storage_device_map_.count(call), 0);

Review comment:
       ```suggestion
         ICHECK_GT(storage_device_map_.count(call), 0) << "Could not find a storage device for " << prim_fn_name << ". This could be caused by an error in GraphPlanMemory.";
   ```
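On the check itself: `count` on a map returns an unsigned size. A comparison of the form `count(key) >= 0` therefore holds for every key, present or not, and cannot catch a missing entry. The standard-library sketch below uses hypothetical helper names to show the difference. A `> 0` comparison, as in `ICHECK_GT`, is what actually detects absence.

```cpp
#include <string>
#include <unordered_map>

// Always true: count() returns an unsigned size_t (compilers typically warn here).
bool VacuousCheck(const std::unordered_map<std::string, int>& m, const std::string& key) {
  return m.count(key) >= 0;
}

// True only when the key is present.
bool PresenceCheck(const std::unordered_map<std::string, int>& m, const std::string& key) {
  return m.count(key) > 0;
}
```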

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
+    }
+
+    // No need to cache external functions as we collected them all to create
+    // external runtime modules.
+    for (const auto& it : cached_ext_funcs) {
+      cache_.erase(it);
+    }
+    return ret;
+  }
+
+  void Clear() final { cache_.clear(); }
+
+  // List all items in the cache.
+  Array<ObjectRef> ListItems() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    Array<ObjectRef> items;
+    for (auto& kv : cache_) {
+      items.push_back(kv.first);
+      items.push_back(kv.second);
+    }
+    return items;
+  }
+
+  /*!
+   * \brief Get the cache key of the function that is being lowered currently
+   * \return the cache key
+   */
+  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+  // implement lowered func
+  CCacheValue LowerInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      if (!backend::IsCompileEngineCacheDisabled()) {
+        cache_[key] = value;
+      }
+    }
+    cur_ccache_key_ = key;
+
+    // No need to lower external functions for now. We will invoke the external
+    // codegen tool once and lower all functions together.
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+      auto ir_module = IRModule();
+      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
+      auto func_name = std::string(name_node.value());
+      auto target = Target("ext_dev");
+      auto global_var = GlobalVar(func_name);
+      global_var->checked_type_ = key->source_func->checked_type();
+      ir_module->Add(global_var, key->source_func);
+      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
+      return value;
+    }
+    // Enforce use of the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+    auto cfunc = PrimFuncFor(key->source_func, key->target,
+                             [&](std::string name) { return GetUniqueName(name, name_map_); });
+
+    // Skip lowering for device copy node.
+    const Expr body = (key->source_func)->body;
+    if (const CallNode* call_node = body.as<CallNode>()) {
+      if (call_node->attrs.as<DeviceCopyAttrs>()) {
+        value->cached_func = cfunc;
+        return value;
+      }
+    }
+
+    std::cout << "Input Size: " << cfunc->inputs.size() << std::endl;
+    std::cout << "Output Size: " << cfunc->outputs.size() << std::endl;

Review comment:
       Remove these `std::cout` debug prints?
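
The lookup-or-insert logic at the top of `LowerInternal` can be modelled without TVM. It locks the mutex and, on a hit, bumps `use_count` and returns early if the entry is already lowered. Otherwise it creates a fresh entry and skips inserting it when caching is disabled. The sketch below is a simplified, hypothetical model. `Entry`, `LoweringCache`, and the string key stand in for `CCacheValue` and `CCacheKey`, and lowering is a caller-supplied callback rather than `PrimFuncFor` plus `tvm::lower`.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for CCacheValueNode: the lowered result (if any)
// plus the number of repeat lookups.
struct Entry {
  std::optional<std::string> lowered;  // plays the role of cached_func
  int use_count = 0;
};

class LoweringCache {
 public:
  explicit LoweringCache(bool cache_disabled = false) : cache_disabled_(cache_disabled) {}

  // Mirrors the shape of LowerInternal: look up the key under the lock,
  // reuse a finished entry, otherwise lower and record the result.
  // The lock is held for the whole lowering, which serializes callers.
  std::shared_ptr<Entry> Lower(const std::string& key,
                               const std::function<std::string(const std::string&)>& lower_fn) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::shared_ptr<Entry> value;
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      it->second->use_count += 1;
      if (it->second->lowered.has_value()) return it->second;  // already lowered
      value = it->second;
    } else {
      value = std::make_shared<Entry>();
      value->use_count = 0;
      if (!cache_disabled_) cache_[key] = value;  // skip insertion when caching is off
    }
    value->lowered = lower_fn(key);
    return value;
  }

  std::size_t size() const { return cache_.size(); }

 private:
  bool cache_disabled_;
  std::mutex mutex_;
  std::unordered_map<std::string, std::shared_ptr<Entry>> cache_;
};
```

As in the hunk, the first lookup leaves `use_count` at 0 and only later hits increment it. With caching disabled, every call lowers again.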

##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -349,65 +380,47 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<G
     return AddNode(node, GetRef<Expr>(op));
   }
 
-  std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
-    Expr expr = GetRef<Expr>(op);
-    Function func;
-    if (op->op.as<OpNode>()) {
-      LOG(FATAL) << "Operators should be transformed away; try applying"
-                 << "the fuse_ops transformation to the expression.";
-    } else if (op->op.as<GlobalVarNode>()) {
-      LOG(FATAL) << "Not implemented";
-    } else if (op->op.as<FunctionNode>()) {
-      func = GetRef<Function>(op->op.as<FunctionNode>());
-    } else {
-      LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
-    }
-    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
-      LOG(FATAL) << "TVM only support calls to primitive functions "
-                 << "(i.e functions composed of fusable operator invocations)";
-    }
+  std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
+    relay::Call call = GetRef<Call>(call_node);
 
-    auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
-    auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
-    Target target;
-    // Handle external function
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      target = Target("ext_dev");
-      CCacheKey key = (*pf0)(func, target);
-      CachedFunc ext_func = (*pf1)(compile_engine_, key);
-      ICHECK(ext_func.defined()) << "External function is not defined.";
-      UpdateConstants(func, &params_);
-      return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
-    }
+    if (auto global_node = call->op.as<GlobalVarNode>()) {
+      auto prim_fn_name = global_node->name_hint;
 
-    ICHECK_GE(storage_device_map_.count(expr), 0);
-    auto& device_type = storage_device_map_[expr][1];
-    auto call_dev_type = device_type[0]->value;
-    // Normal Relay Function
-    if (targets_.size() == 1) {
-      // homogeneous execution.
-      const auto& it = targets_.begin();
-      target = (*it).second;
-    } else {
-      // heterogeneous execution.
-      std::string call_dev_name;
-      if (call_dev_type == 0) {
-        call_dev_name = "llvm";
+      Target target;
+
+      // // Handle external function
+      // if (func->GetAttr<String>(attr::kCompiler).defined()) {
+      //   UpdateConstants(func, &params_);
+      //   return GraphAddCallNode(call_node, prim_fn_name, prim_fn_name);
+      // }

Review comment:
       Delete this commented-out external-function branch?

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    CCacheValue value = LowerInternal(key);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        ICHECK(code_gen.defined()) << "No external codegen is set";
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";

Review comment:
       ```suggestion
           ICHECK(ext_mod.defined()) << "No external runtime was generated by " << ext_name << ".";
   ```
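
`LowerExternalFunctions` follows a fixed sequence. It skips entries without a `Compiler` attribute and looks up the `relay.ext.<compiler>` codegen. It then requires that codegen to produce a module, and finally erases the external entries from the cache. The sketch below models that flow with hypothetical stand-in types instead of TVM's `Registry::Get` and `runtime::Module`: a `SourceFunc` struct, a `std::map` registry, and modules represented as strings. It uses the error message proposed in the suggestion above.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical simplified function: its global symbol plus an optional
// "Compiler" attribute naming an external codegen.
struct SourceFunc {
  std::string global_symbol;
  std::optional<std::string> compiler;
};

// Stand-in for the global registry: name -> codegen callback.
// A codegen returns an empty string to signal "no module generated".
using Codegen = std::function<std::string(const SourceFunc&)>;
using Registry = std::map<std::string, Codegen>;

// Invoke the "relay.ext.<compiler>" codegen for every external function, then
// drop those entries from the cache, mirroring LowerExternalFunctions.
std::vector<std::string> LowerExternal(std::map<std::string, SourceFunc>* cache,
                                       const Registry& registry) {
  std::vector<std::string> modules;
  std::vector<std::string> external_keys;
  for (const auto& [key, func] : *cache) {
    if (!func.compiler) continue;  // TVM-lowered functions stay in the cache
    external_keys.push_back(key);
    const std::string ext_name = "relay.ext." + *func.compiler;
    auto it = registry.find(ext_name);
    if (it == registry.end()) {
      throw std::runtime_error("Failed to find the codegen tool for " + ext_name);
    }
    std::string mod = it->second(func);
    if (mod.empty()) {
      // Name the codegen in the message, as the review suggestion proposes.
      throw std::runtime_error("No external runtime was generated by " + ext_name + ".");
    }
    modules.push_back(mod);
  }
  for (const auto& key : external_keys) cache->erase(key);
  return modules;
}
```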

##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
+  auto n = make_object<LoweredOutputNode>();
+  n->outputs = std::move(outputs);
+  n->implementation = std::move(impl);
+  data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+  auto n = make_object<CCacheKeyNode>();
+  n->source_func = std::move(source_func);
+  n->target = std::move(target);
+  data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array<te::Tensor> inputs,
+                       tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+                       tvm::Array<Integer> shape_func_param_states, IRModule funcs) {
+  auto n = make_object<CachedFuncNode>();
+  n->target = target;
+  n->prim_fn_var = prim_fn_var;
+  n->inputs = inputs;
+  n->outputs = outputs;
+  n->schedule = schedule;
+  n->shape_func_param_states = shape_func_param_states;
+  n->funcs = funcs;
+  data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+  // for now, we always use int32 shape when possible
+  // even if the result of shape inference becomes int64.
+  Array<IndexExpr> res;
+  for (IndexExpr val : shape) {
+    const int64_t* pval = tir::as_const_int(val);
+    if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+      ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());

Review comment:
       Can we have an error message on these?
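
In TVM the `ICHECK_*` macros accept a streamed message, e.g. `ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max()) << "..."`. As a TVM-free illustration of what such a message might report, here is a hypothetical helper, `NarrowShapeDimToInt32`. It performs the same int32 range check as `GetShape` and names both the offending value and the allowed range.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <sstream>
#include <stdexcept>

// Hypothetical helper: narrow a constant shape dimension to int32, reporting
// the offending value and the allowed range when it does not fit.
std::int32_t NarrowShapeDimToInt32(std::int64_t dim) {
  constexpr std::int64_t kMin = std::numeric_limits<std::int32_t>::min();
  constexpr std::int64_t kMax = std::numeric_limits<std::int32_t>::max();
  if (dim < kMin || dim > kMax) {
    std::ostringstream msg;
    msg << "Shape dimension " << dim << " does not fit in int32 [" << kMin << ", " << kMax
        << "]; build with TVM_INDEX_DEFAULT_I64 defined to keep int64 shapes.";
    throw std::out_of_range(msg.str());
  }
  return static_cast<std::int32_t>(dim);
}
```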

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    CCacheValue value = LowerInternal(key);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        ICHECK(code_gen.defined()) << "No external codegen is set";
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
+    }
+
+    // No need to cache external functions as we collected them all to create
+    // external runtime modules.
+    for (const auto& it : cached_ext_funcs) {
+      cache_.erase(it);
+    }
+    return ret;
+  }
+
+  void Clear() final { cache_.clear(); }
+
+  // List all items in the cache.
+  Array<ObjectRef> ListItems() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    Array<ObjectRef> items;
+    for (auto& kv : cache_) {
+      items.push_back(kv.first);
+      items.push_back(kv.second);
+    }
+    return items;
+  }
+
+  /*!
+   * \brief Get the cache key of the function that is being lowered currently
+   * \return the cache key
+   */
+  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+  // implement lowered func
+  CCacheValue LowerInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      if (!backend::IsCompileEngineCacheDisabled()) {
+        cache_[key] = value;
+      }
+    }
+    cur_ccache_key_ = key;
+
+    // No need to lower external functions for now. We will invoke the external
+    // codegen tool once and lower all functions together.
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+      auto ir_module = IRModule();
+      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
+      auto func_name = std::string(name_node.value());
+      auto target = Target("ext_dev");
+      auto global_var = GlobalVar(func_name);
+      global_var->checked_type_ = key->source_func->checked_type();
+      ir_module->Add(global_var, key->source_func);
+      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
+      return value;
+    }
+    // Enforce use of the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+    auto cfunc = PrimFuncFor(key->source_func, key->target,
+                             [&](std::string name) { return GetUniqueName(name, name_map_); });
+
+    // Skip lowering for device copy node.
+    const Expr body = (key->source_func)->body;
+    if (const CallNode* call_node = body.as<CallNode>()) {
+      if (call_node->attrs.as<DeviceCopyAttrs>()) {
+        value->cached_func = cfunc;
+        return value;
+      }
+    }
+
+    std::cout << "Input Size: " << cfunc->inputs.size() << std::endl;
+    std::cout << "Output Size: " << cfunc->outputs.size() << std::endl;
+    // NOTE: array will copy on write.
+    Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+    for (te::Tensor arg : cfunc->outputs) {
+      all_args.push_back(arg);
+    }
+
+    std::cout << "Allargs Size: " << all_args.size() << std::endl;
+
+    using tvm::transform::PassContext;
+    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+
+    std::unordered_map<te::Tensor, tir::Buffer> binds;
+    auto func_name = cfunc->prim_fn_var->name_hint;
+    cfunc->funcs->Update(tvm::lower(cfunc->schedule, all_args, func_name, binds));
+    value->cached_func = cfunc;
+    return value;
+  }
+
+  // implement lowered shape func
+  CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = shape_func_cache_.find(key);
+    if (it != shape_func_cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      shape_func_cache_[key] = value;
+    }
+    // Enforce use of the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+    auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) {
+      return GetUniqueName(name, name_map_);
+    });
+
+    value->cached_func = cached_func;
+    return value;
+  }
+
+  /*! \brief compiler cache lock*/
+  std::mutex mutex_;
+  /*! \brief internal name map to get a unique name */
+  std::unordered_map<std::string, int> name_map_;
+  /*! \brief internal compiler cache */
+  std::unordered_map<CCacheKey, CCacheValue> cache_;
+  /*! \brief internal compiler cache for shape funcs */
+  std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
+  /*! \brief the cache key of the function that is being lowered currently*/
+  CCacheKey cur_ccache_key_;
+};
+
+TECompiler::TECompiler() {
+  auto object = make_object<TECompilerImpl>();
+  data_ = object;
+}
+
+class LowerTensorExpr : public ExprMutator {
+ public:
+  LowerTensorExpr(const IRModule& module, const TargetsMap& targets,
+                  const DeviceContextMap& device_ctx_map, TECompiler compiler)
+      : module_(module),
+        targets_(targets),
+        device_context_map_(device_ctx_map),
+        compiler_(compiler) {}
+
+  Expr VisitExpr_(const CallNode* call) override {
+    Call expr = GetRef<Call>(call);
+    Function func;
+
+    if (call->op.as<FunctionNode>()) {
+      func = GetRef<Function>(call->op.as<FunctionNode>());
+    } else {
+      return ExprMutator::VisitExpr_(call);
+    }
+
+    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
+      // LOG(FATAL) << "TVM only support calls to primitive functions "
+      //           << "(i.e functions composed of fusable operator invocations)";

Review comment:
       Delete or uncomment this `LOG(FATAL)` check?
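
Per the PR description, the core rewrite in `LowerTensorExpr` replaces a call to a `primitive` Relay function with a call to a global that names the lowered TIR function. The toy IR below sketches that rewrite using hypothetical `Function` and `Call` structs, not Relay's classes. It assumes non-primitive callees are passed through unchanged, which is one of the two options this review comment raises. The real pass also lowers the function through the `TECompiler`, which is reduced here to recording its name.

```cpp
#include <cassert>
#include <string>
#include <variant>
#include <vector>

// Hypothetical toy IR: a callee is either an inline function (possibly marked
// primitive) or a global referenced by name.
struct Function {
  bool primitive = false;  // models the kPrimitive attribute
  std::string name;        // name the lowered function would receive
};

struct Call {
  std::variant<Function, std::string> callee;
  std::vector<Call> args;
};

// Replace calls to primitive functions with calls to a global symbol
// (standing in for the lowered PrimFunc); recurse into arguments first.
Call LowerPrimitiveCalls(const Call& call, std::vector<std::string>* lowered) {
  Call out;
  for (const Call& arg : call.args) out.args.push_back(LowerPrimitiveCalls(arg, lowered));
  if (const auto* fn = std::get_if<Function>(&call.callee); fn && fn->primitive) {
    lowered->push_back(fn->name);  // record what would be handed to the compiler
    out.callee = fn->name;         // the call now targets a global by name
  } else {
    out.callee = call.callee;      // non-primitive or global callees pass through
  }
  return out;
}
```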

##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
+  auto n = make_object<LoweredOutputNode>();
+  n->outputs = std::move(outputs);
+  n->implementation = std::move(impl);
+  data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+  auto n = make_object<CCacheKeyNode>();
+  n->source_func = std::move(source_func);
+  n->target = std::move(target);
+  data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array<te::Tensor> inputs,
+                       tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+                       tvm::Array<Integer> shape_func_param_states, IRModule funcs) {
+  auto n = make_object<CachedFuncNode>();
+  n->target = target;
+  n->prim_fn_var = prim_fn_var;
+  n->inputs = inputs;
+  n->outputs = outputs;
+  n->schedule = schedule;
+  n->shape_func_param_states = shape_func_param_states;
+  n->funcs = funcs;
+  data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+  // for now, we always use int32 shape when possible
+  // even if the result of shape inference becomes int64.
+  Array<IndexExpr> res;
+  for (IndexExpr val : shape) {
+    const int64_t* pval = tir::as_const_int(val);
+    if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+      ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+      ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+      res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+      res.push_back(val);
+#endif  // TVM_INDEX_DEFAULT_I64
+    } else if (val->IsInstance<tir::AnyNode>()) {
+      res.push_back(val.as<tir::AnyNode>()->ToVar());
+    } else {
+      res.push_back(val);
+    }
+  }
+  return res;
+}
+
+// The getter to get schedule from compile engine.
+// Get schedule from functor.
+class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+  explicit ScheduleGetter(Target target)
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {
+    // Whether to use auto_scheduler schedule.
+    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+  }
+
+  CachedFunc Create(const Function& prim_func, std::function<std::string(std::string)> renamer) {
+    Array<tvm::te::Tensor> fn_inputs;
+    for (Var param : prim_func->params) {
+      Array<tvm::te::Tensor> inputs;
+      if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+        tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+        fn_inputs.push_back(tensor);
+        inputs.push_back(tensor);
+      } else {
+        // flatten tuple of tensor type.
+        const auto* tuple_type = param->type_as<TupleTypeNode>();
+        for (Type field : tuple_type->fields) {
+          const auto* ttype = field.as<TensorTypeNode>();
+          // TODO(@icemelon): Allow recursive tuple
+          ICHECK(ttype != nullptr);
+          tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+          fn_inputs.push_back(tensor);
+          inputs.push_back(tensor);
+        }
+      }
+      memo_[param] = inputs;
+    }
+    readable_name_stream_ << "fused";
+    auto outputs = this->VisitExpr(prim_func->body);
+    auto candidate_name = readable_name_stream_.str();
+    constexpr static size_t kMaxFuncNameLength = 80;
+    if (candidate_name.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+      candidate_name = truncated_name.str();
+    }
+
+    ICHECK(anchor_op_.defined());
+    // Fusion over tupled results may leave identity relationships
+    // between inputs and outputs, and those should not be scheduled.
+    // Hence schedule only non PlaceholderOp outputs.
+    tvm::Array<te::Tensor> tensor_outs;
+    for (const auto& tensor : outputs) {
+      if (!tensor->op.as<te::PlaceholderOpNode>()) {
+        tensor_outs.push_back(tensor);
+      }
+    }
+
+    te::Schedule schedule;
+    // No need to register schedule for device copy op.
+    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
+      if (use_auto_scheduler_) {
+        const auto* fauto_schedule =
+            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+        ICHECK(fauto_schedule != nullptr)
+            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
+        ObjectRef obj = (*fauto_schedule)(tensor_outs);
+        if (obj.defined()) {
+          schedule = Downcast<te::Schedule>(obj);
+        }
+      }
+
+      // Use the TOPI schedule if the user specified one, or if the function has no auto_scheduler schedule.
+      if (!schedule.defined()) {
+        ICHECK(anchor_implementation_.defined());
+        schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
+      }
+      for (const auto& scalar : scalars_) {
+        if (schedule->Contain(scalar)) {
+          schedule[scalar].compute_inline();
+        }
+      }
+    }
+
+    auto prim_fn_var = GlobalVar(candidate_name);
+    prim_fn_var->checked_type_ = prim_func->checked_type();
+    return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {});
+  }
+
+  Array<te::Tensor> VisitExpr_(const VarNode* op) final {
+    LOG(FATAL) << "Free variable " << op->name_hint();
+    return {};
+  }
+
+  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
+    using tir::make_const;
+    ICHECK(op->is_scalar());
+    void* data = op->data->data;
+    DataType dtype = DataType(op->data->dtype);
+    auto value = te::compute(
+        {},
+        [&](const Array<tvm::tir::Var>&) {
+          if (dtype == DataType::Int(32)) {
+            return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+          } else if (dtype == DataType::Int(64)) {
+            return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+          } else if (dtype == DataType::Float(32)) {
+            return make_const(dtype, static_cast<const float*>(data)[0]);
+          } else if (dtype == DataType::Float(64)) {
+            return make_const(dtype, static_cast<const double*>(data)[0]);
+          } else if (dtype == DataType::Bool()) {
+            return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+          } else {
+            LOG(FATAL) << "not handled";
+            return tvm::PrimExpr();
+          }
+        },
+        "compile_engine_const", topi::kBroadcast);
+    scalars_.push_back(value->op);
+    return {value};
+  }
+
+  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
+    ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
+
+    Array<te::Tensor> inputs;
+    int count_tuple = 0;
+    for (Expr arg : call_node->args) {
+      if (arg->checked_type().as<TupleTypeNode>()) {
+        ++count_tuple;
+      }
+      for (te::Tensor tensor : VisitExpr(arg)) {
+        inputs.push_back(tensor);
+      }
+    }
+    if (count_tuple) {
+      ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";

Review comment:
       ```suggestion
         ICHECK_EQ(call_node->args.size(), 1U) << "Only functions with a single tuple input are allowed, but " << call_node->args.size() << " were provided.";
   ```
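
The suggested message change sits inside the tuple-flattening loop quoted above. Below is a minimal standalone sketch of that loop. It assumes a hypothetical `Arg` struct in place of Relay call arguments and a thrown exception in place of `ICHECK_EQ`. Every argument is flattened into input tensors, and if any argument is a tuple, the call must have exactly one argument. The error message reports the actual count, as the suggestion proposes.

```cpp
#include <cassert>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for a Relay call argument: either a single tensor
// or a tuple of `num_fields` tensors.
struct Arg {
  bool is_tuple;
  int num_fields;  // only meaningful when is_tuple is true
};

// Returns the number of flattened input tensors. Throws if a tuple argument
// appears in a call with more than one argument, reporting the actual count.
int CountFlattenedInputs(const std::vector<Arg>& args) {
  int count_tuple = 0;
  int num_inputs = 0;
  for (const Arg& arg : args) {
    if (arg.is_tuple) {
      ++count_tuple;
      num_inputs += arg.num_fields;  // tuple fields are flattened
    } else {
      num_inputs += 1;
    }
  }
  if (count_tuple > 0 && args.size() != 1) {
    std::ostringstream msg;
    msg << "Only functions with a single tuple input are allowed, but " << args.size()
        << " were provided.";
    throw std::invalid_argument(msg.str());
  }
  return num_inputs;
}
```

Including the observed value in the message lets a failing build report what went wrong without a debugger.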

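The `ConstantNode` visitor in this hunk reads element 0 of the constant's raw buffer by casting it to the C type that matches the dtype. Below is a hedged sketch of that dispatch outside TVM, with these assumptions:

- The `ScalarType` enum is a hypothetical stand-in for `tvm::DataType`.
- Values are widened to `double` only to make them easy to inspect. Unlike the original `make_const`, this loses precision for `int64` values above 2^53.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical dtype tag standing in for tvm::DataType.
enum class ScalarType { kInt32, kInt64, kFloat32, kFloat64, kBool };

// Reads element 0 of a raw scalar buffer and widens it to double.
// Booleans are read as one byte, matching the original uint8_t cast.
double ReadScalar(const void* data, ScalarType type) {
  switch (type) {
    case ScalarType::kInt32:
      return static_cast<double>(static_cast<const int32_t*>(data)[0]);
    case ScalarType::kInt64:
      return static_cast<double>(static_cast<const int64_t*>(data)[0]);
    case ScalarType::kFloat32:
      return static_cast<const float*>(data)[0];
    case ScalarType::kFloat64:
      return static_cast<const double*>(data)[0];
    case ScalarType::kBool:
      return static_cast<const uint8_t*>(data)[0] != 0 ? 1.0 : 0.0;
  }
  // Counterpart of LOG(FATAL) << "not handled" for any other dtype.
  throw std::invalid_argument("not handled");
}
```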
##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
+  auto n = make_object<LoweredOutputNode>();
+  n->outputs = std::move(outputs);
+  n->implementation = std::move(impl);
+  data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+  auto n = make_object<CCacheKeyNode>();
+  n->source_func = std::move(source_func);
+  n->target = std::move(target);
+  data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array<te::Tensor> inputs,
+                       tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+                       tvm::Array<Integer> shape_func_param_states, IRModule funcs) {
+  auto n = make_object<CachedFuncNode>();
+  n->target = target;
+  n->prim_fn_var = prim_fn_var;
+  n->inputs = inputs;
+  n->outputs = outputs;
+  n->schedule = schedule;
+  n->shape_func_param_states = shape_func_param_states;
+  n->funcs = funcs;
+  data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+  // for now, we always use int32 shape when possible
+  // even if the result of shape inference becomes int64.
+  Array<IndexExpr> res;
+  for (IndexExpr val : shape) {
+    const int64_t* pval = tir::as_const_int(val);
+    if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+      ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+      ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+      res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+      res.push_back(val);
+#endif  // TVM_INDEX_DEFAULT_I64
+    } else if (val->IsInstance<tir::AnyNode>()) {
+      res.push_back(val.as<tir::AnyNode>()->ToVar());
+    } else {
+      res.push_back(val);
+    }
+  }
+  return res;
+}
+
+// The getter to get schedule from compile engine.
+// Get schedule from functor.
+class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+  explicit ScheduleGetter(Target target)
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {
+    // Whether to use auto_scheduler schedule.
+    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+  }
+
+  CachedFunc Create(const Function& prim_func, std::function<std::string(std::string)> renamer) {
+    Array<tvm::te::Tensor> fn_inputs;
+    for (Var param : prim_func->params) {
+      Array<tvm::te::Tensor> inputs;
+      if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+        tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+        fn_inputs.push_back(tensor);
+        inputs.push_back(tensor);
+      } else {
+        // flatten tuple of tensor type.
+        const auto* tuple_type = param->type_as<TupleTypeNode>();
+        for (Type field : tuple_type->fields) {
+          const auto* ttype = field.as<TensorTypeNode>();
+          // TODO(@icemelon): Allow recursive tuple
+          ICHECK(ttype != nullptr);
+          tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+          fn_inputs.push_back(tensor);
+          inputs.push_back(tensor);
+        }
+      }
+      memo_[param] = inputs;
+    }
+    readable_name_stream_ << "fused";
+    auto outputs = this->VisitExpr(prim_func->body);
+    auto candidate_name = readable_name_stream_.str();
+    constexpr static size_t kMaxFuncNameLength = 80;
+    if (candidate_name.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+      candidate_name = truncated_name.str();
+    }
+
+    ICHECK(anchor_op_.defined());
+    // Fusion over tupled results may leave identity relationships
+    // between inputs and outputs, and those should not be scheduled.
+    // Hence schedule only non PlaceholderOp outputs.
+    tvm::Array<te::Tensor> tensor_outs;
+    for (const auto& tensor : outputs) {
+      if (!tensor->op.as<te::PlaceholderOpNode>()) {
+        tensor_outs.push_back(tensor);
+      }
+    }
+
+    te::Schedule schedule;
+    // No need to register schedule for device copy op.
+    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
+      if (use_auto_scheduler_) {
+        const auto* fauto_schedule =
+            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+        ICHECK(fauto_schedule != nullptr)
+            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
+        ObjectRef obj = (*fauto_schedule)(tensor_outs);
+        if (obj.defined()) {
+          schedule = Downcast<te::Schedule>(obj);
+        }
+      }
+
+      // Use TOPI schedule if user specified, or the function has no auto_scheduler schedule.
+      if (!schedule.defined()) {
+        ICHECK(anchor_implementation_.defined());
+        schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
+      }
+      for (const auto& scalar : scalars_) {
+        if (schedule->Contain(scalar)) {
+          schedule[scalar].compute_inline();
+        }
+      }
+    }
+
+    auto prim_fn_var = GlobalVar(candidate_name);
+    prim_fn_var->checked_type_ = prim_func->checked_type();
+    return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {});
+  }
+
+  Array<te::Tensor> VisitExpr_(const VarNode* op) final {
+    LOG(FATAL) << "Free variable " << op->name_hint();
+    return {};
+  }
+
+  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
+    using tir::make_const;
+    ICHECK(op->is_scalar());
+    void* data = op->data->data;
+    DataType dtype = DataType(op->data->dtype);
+    auto value = te::compute(
+        {},
+        [&](const Array<tvm::tir::Var>&) {
+          if (dtype == DataType::Int(32)) {
+            return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+          } else if (dtype == DataType::Int(64)) {
+            return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+          } else if (dtype == DataType::Float(32)) {
+            return make_const(dtype, static_cast<const float*>(data)[0]);
+          } else if (dtype == DataType::Float(64)) {
+            return make_const(dtype, static_cast<const double*>(data)[0]);
+          } else if (dtype == DataType::Bool()) {
+            return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+          } else {
+            LOG(FATAL) << "not handled";
+            return tvm::PrimExpr();
+          }
+        },
+        "compile_engine_const", topi::kBroadcast);
+    scalars_.push_back(value->op);
+    return {value};
+  }
+
+  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
+    ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
+
+    Array<te::Tensor> inputs;
+    int count_tuple = 0;
+    for (Expr arg : call_node->args) {
+      if (arg->checked_type().as<TupleTypeNode>()) {
+        ++count_tuple;
+      }
+      for (te::Tensor tensor : VisitExpr(arg)) {
+        inputs.push_back(tensor);
+      }
+    }
+    if (count_tuple) {
+      ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
+    }
+
+    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
+    Op op = Downcast<Op>(call_node->op);
+
+    Array<te::Tensor> outputs;
+    OpImplementation impl;
+    // Skip fcompute for device copy operators as it is not registered.
+    if (op == device_copy_op_) {
+      const auto* copy_input = inputs[0].operator->();
+      outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
+    } else {
+      LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
+      outputs = lowered_out->outputs;
+      impl = lowered_out->implementation;
+    }
+
+    int op_pattern = fpattern[op];
+    if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+      ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+          << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
+          << " anchor=" << anchor_op_ << " current=" << op;
+    }
+    if (op_pattern >= anchor_op_pattern_) {
+      anchor_op_ = op;
+      anchor_attrs_ = call_node->attrs;
+      anchor_op_pattern_ = op_pattern;
+      anchor_implementation_ = impl;
+    }
+    if (outputs.size() != 1) {
+      const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type) << "Expect output to be a tuple type";

Review comment:
       Add actual type here
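
The checks just above the commented line depend on how `ScheduleGetter` picks its anchor op. The rule works as follows:

- The op with the highest pattern wins.
- Ties go to the op visited later (`>=`).
- Without auto_scheduler, a second op at `kCommReduce` or above is rejected, because one TOPI schedule cannot cover two such ops.

Below is a standalone sketch of that rule. The pattern values are copied from Relay's `OpPatternKind` and should be treated as an assumption if that enum has since changed. `FusedOp` is hypothetical.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Subset of Relay's OpPatternKind values (assumed unchanged).
enum OpPattern {
  kElemWise = 0,
  kBroadcast = 1,
  kInjective = 2,
  kCommReduce = 3,
  kOutEWiseFusable = 4,
};

// Hypothetical flattened view of one call inside a fused function.
struct FusedOp {
  std::string name;
  int pattern;
};

// Returns the anchor op name, visiting ops in the order given.
std::string SelectAnchor(const std::vector<FusedOp>& ops, bool use_auto_scheduler) {
  std::string anchor;
  int anchor_pattern = kElemWise;  // the original also starts from 0
  bool defined = false;
  for (const FusedOp& op : ops) {
    // Reject a second "complicated" op when only TOPI schedules are available.
    if (!use_auto_scheduler && op.pattern >= kCommReduce && defined &&
        anchor_pattern >= kCommReduce) {
      throw std::runtime_error("Cannot apply TOPI schedule to a primitive function with two "
                               "complicated ops anchor=" + anchor + " current=" + op.name);
    }
    if (op.pattern >= anchor_pattern) {  // ties go to the later op
      anchor = op.name;
      anchor_pattern = op.pattern;
      defined = true;
    }
  }
  return anchor;
}
```

With auto_scheduler enabled the check is skipped, and the later reduction simply does not displace the higher-pattern anchor.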

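`ScheduleGetter::Create` in this hunk caps generated function names at `kMaxFuncNameLength` (80) characters. It keeps the prefix and appends `std::hash` of the full name, so two long fused names that share a prefix still get distinct symbols, barring a hash collision. The sketch below reproduces that step as a free function. The function name `TruncateFuncName` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <sstream>
#include <string>

// Keeps names up to 80 characters unchanged; longer names become
// "<first 80 chars>_<hash of full name>_".
std::string TruncateFuncName(const std::string& candidate_name) {
  constexpr std::size_t kMaxFuncNameLength = 80;
  if (candidate_name.size() <= kMaxFuncNameLength) return candidate_name;
  std::stringstream truncated_name;
  truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
  truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
  return truncated_name.str();
}
```

The hash value depends on the standard library implementation, so truncated names are not guaranteed to be stable across toolchains.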
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    CCacheValue value = LowerInternal(key);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        ICHECK(code_gen.defined()) << "No external codegen is set";
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
+    }
+
+    // No need to cache external functions as we collected them all to create
+    // external runtime modules.
+    for (const auto& it : cached_ext_funcs) {
+      cache_.erase(it);
+    }
+    return ret;
+  }
+
+  void Clear() final { cache_.clear(); }
+
+  // List all items in the cache.
+  Array<ObjectRef> ListItems() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    Array<ObjectRef> items;
+    for (auto& kv : cache_) {
+      items.push_back(kv.first);
+      items.push_back(kv.second);
+    }
+    return items;
+  }
+
+  /*!
+   * \brief Get the cache key of the function that is being lowered currently
+   * \return the cache key
+   */
+  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+  // implement lowered func
+  CCacheValue LowerInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      if (!backend::IsCompileEngineCacheDisabled()) {
+        cache_[key] = value;
+      }
+    }
+    cur_ccache_key_ = key;
+
+    // No need to lower external functions for now. We will invoke the external
+    // codegen tool once and lower all functions together.
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+      auto ir_module = IRModule();
+      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";

Review comment:
       Could you add the function name to this?
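
The error message in question lives in `LowerInternal`, whose cache discipline is easy to miss in the diff. It works in three steps:

1. Take the lock.
2. On a hit, bump `use_count` and return the value if it is already lowered.
3. Otherwise create an entry (stored only when the cache is enabled) and lower it.

Below is a simplified sketch of that discipline. It uses hypothetical string keys and a caller-supplied lowering callback instead of `CCacheKey` and `PrimFuncFor`, and it omits the external-codegen and device-copy early returns.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for CCacheValueNode.
struct CacheValue {
  std::string lowered;  // empty until lowering has run
  int use_count = 0;    // incremented on every repeat lookup
};

class LoweringCache {
 public:
  explicit LoweringCache(bool cache_disabled = false) : cache_disabled_(cache_disabled) {}

  std::shared_ptr<CacheValue> Lower(const std::string& key,
                                    const std::function<std::string(const std::string&)>& lower) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::shared_ptr<CacheValue> value;
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      it->second->use_count += 1;
      if (!it->second->lowered.empty()) return it->second;  // cache hit
      value = it->second;
    } else {
      value = std::make_shared<CacheValue>();
      if (!cache_disabled_) cache_[key] = value;  // only remembered if caching is on
    }
    value->lowered = lower(key);
    return value;
  }

  // Unlocked; intended for single-threaded inspection only.
  std::size_t size() const { return cache_.size(); }

 private:
  bool cache_disabled_;
  std::mutex mutex_;
  std::unordered_map<std::string, std::shared_ptr<CacheValue>> cache_;
};
```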

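`GetLoweredFunctions` in this hunk produces one `IRModule` per target string and merges every cache entry compiled for that target into it. Below is a sketch of that grouping with plain containers. `LoweredEntry` is hypothetical, and function names stand in for `IRModule` contents. Unlike `IRModule::Update`, this simply appends, so duplicate names are not handled.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical flattened cache entry: its target string and lowered function names.
struct LoweredEntry {
  std::string target;
  std::vector<std::string> funcs;
};

std::map<std::string, std::vector<std::string>> GroupByTarget(
    const std::vector<LoweredEntry>& entries) {
  std::map<std::string, std::vector<std::string>> modules;
  for (const LoweredEntry& entry : entries) {
    // operator[] creates an empty "module" the first time a target is seen.
    std::vector<std::string>& module = modules[entry.target];
    module.insert(module.end(), entry.funcs.begin(), entry.funcs.end());
  }
  return modules;
}
```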
##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    CCacheValue value = LowerInternal(key);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        ICHECK(code_gen.defined()) << "No external codegen is set";
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
+    }
+
+    // No need to cache external functions as we collected them all to create
+    // external runtime modules.
+    for (const auto& it : cached_ext_funcs) {
+      cache_.erase(it);
+    }
+    return ret;
+  }
+
+  void Clear() final { cache_.clear(); }
+
+  // List all items in the cache.
+  Array<ObjectRef> ListItems() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    Array<ObjectRef> items;
+    for (auto& kv : cache_) {
+      items.push_back(kv.first);
+      items.push_back(kv.second);
+    }
+    return items;
+  }
+
+  /*!
+   * \brief Get the cache key of the function that is being lowered currently
+   * \return the cache key
+   */
+  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+  // implement lowered func
+  CCacheValue LowerInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      if (!backend::IsCompileEngineCacheDisabled()) {
+        cache_[key] = value;
+      }
+    }
+    cur_ccache_key_ = key;
+
+    // No need to lower external functions for now. We will invoke the external
+    // codegen tool once and lower all functions together.
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+      auto ir_module = IRModule();
+      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
+      auto func_name = std::string(name_node.value());
+      auto target = Target("ext_dev");
+      auto global_var = GlobalVar(func_name);
+      global_var->checked_type_ = key->source_func->checked_type();
+      ir_module->Add(global_var, key->source_func);
+      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
+      return value;
+    }
+    // Enforce use the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+    auto cfunc = PrimFuncFor(key->source_func, key->target,
+                             [&](std::string name) { return GetUniqueName(name, name_map_); });
+
+    // Skip lowering for device copy node.
+    const Expr body = (key->source_func)->body;
+    if (const CallNode* call_node = body.as<CallNode>()) {
+      if (call_node->attrs.as<DeviceCopyAttrs>()) {
+        value->cached_func = cfunc;
+        return value;
+      }
+    }
+
+    std::cout << "Input Size: " << cfunc->inputs.size() << std::endl;
+    std::cout << "Output Size: " << cfunc->outputs.size() << std::endl;
+    // NOTE: array will copy on write.
+    Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+    for (te::Tensor arg : cfunc->outputs) {
+      all_args.push_back(arg);
+    }
+
+    std::cout << "Allargs Size: " << all_args.size() << std::endl;

Review comment:
       remove?
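
If the size printouts are still wanted while debugging, one option besides deleting them is to route them through a switch that is off by default. The `DebugLog` class below is a hypothetical helper, not a TVM API; in-tree, TVM's existing logging macros would likely be the more idiomatic choice. Stream manipulators such as `std::endl` are not supported by this template, so `"\n"` is used instead.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>

// Hypothetical helper: forwards output to `out` only when `enabled` is true.
class DebugLog {
 public:
  DebugLog(std::ostream& out, bool enabled) : out_(out), enabled_(enabled) {}

  template <typename T>
  DebugLog& operator<<(const T& value) {
    if (enabled_) out_ << value;
    return *this;
  }

 private:
  std::ostream& out_;
  bool enabled_;
};
```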

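The code around the debug prints builds `all_args` as every input tensor followed by every output tensor. Outputs are passed as buffers that the lowered function writes into. The sketch below shows that ordering, with strings standing in for `te::Tensor`. The name `AllArgs` is hypothetical.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Inputs first, then outputs, matching the order LowerInternal builds.
std::vector<std::string> AllArgs(const std::vector<std::string>& inputs,
                                 const std::vector<std::string>& outputs) {
  // Explicit copy; tvm::Array achieves the same lazily via copy-on-write.
  std::vector<std::string> all_args(inputs);
  all_args.insert(all_args.end(), outputs.begin(), outputs.end());
  return all_args;
}
```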
##########
File path: src/runtime/graph/graph_runtime.cc
##########
@@ -428,6 +429,7 @@ std::pair<std::function<void()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRu
   ICHECK(pf != nullptr) << "no such function in module: " << param.func_name;
 
   auto fexec = [arg_ptr, pf]() {
+    std::cout << "Number of args: " << static_cast<int>(arg_ptr->arg_values.size());

Review comment:
       remove?
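
The print being questioned sits inside the `fexec` closure from the graph_runtime.cc hunk. That closure captures the shared argument bundle and the function by value, so both stay alive as long as the closure does, and each call reads the argument values current at that moment. Below is a sketch of the pattern in which `OpArgs` holds plain `int`s and `MakeOp` is a hypothetical factory.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for GraphRuntime::OpArgs.
struct OpArgs {
  std::vector<int> arg_values;
};

std::pair<std::function<void()>, std::shared_ptr<OpArgs>> MakeOp(
    std::function<void(const std::vector<int>&)> pf, std::vector<int> values) {
  auto arg_ptr = std::make_shared<OpArgs>();
  arg_ptr->arg_values = std::move(values);
  // Capturing the shared_ptr keeps the arguments alive for the closure's lifetime.
  auto fexec = [arg_ptr, pf]() { pf(arg_ptr->arg_values); };
  return std::make_pair(std::function<void()>(fexec), arg_ptr);
}
```

Returning the `shared_ptr` alongside the closure lets the caller update argument values later, and the closure sees those updates.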

##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
+  auto n = make_object<LoweredOutputNode>();
+  n->outputs = std::move(outputs);
+  n->implementation = std::move(impl);
+  data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+  auto n = make_object<CCacheKeyNode>();
+  n->source_func = std::move(source_func);
+  n->target = std::move(target);
+  data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array<te::Tensor> inputs,
+                       tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+                       tvm::Array<Integer> shape_func_param_states, IRModule funcs) {
+  auto n = make_object<CachedFuncNode>();
+  n->target = target;
+  n->prim_fn_var = prim_fn_var;
+  n->inputs = inputs;
+  n->outputs = outputs;
+  n->schedule = schedule;
+  n->shape_func_param_states = shape_func_param_states;
+  n->funcs = funcs;
+  data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+  // for now, we always use int32 shape when possible
+  // even if the result of shape inference becomes int64.
+  Array<IndexExpr> res;
+  for (IndexExpr val : shape) {
+    const int64_t* pval = tir::as_const_int(val);
+    if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+      ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+      ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+      res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+      res.push_back(val);
+#endif  // TVM_INDEX_DEFAULT_I64
+    } else if (val->IsInstance<tir::AnyNode>()) {
+      res.push_back(val.as<tir::AnyNode>()->ToVar());
+    } else {
+      res.push_back(val);
+    }
+  }
+  return res;
+}
+
+// The getter to get schedule from compile engine.
+// Get schedule from functor.
+class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+  explicit ScheduleGetter(Target target)
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {
+    // Whether to use auto_scheduler schedule.
+    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+  }
+
+  CachedFunc Create(const Function& prim_func, std::function<std::string(std::string)> renamer) {
+    Array<tvm::te::Tensor> fn_inputs;
+    for (Var param : prim_func->params) {
+      Array<tvm::te::Tensor> inputs;
+      if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+        tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+        fn_inputs.push_back(tensor);
+        inputs.push_back(tensor);
+      } else {
+        // flatten tuple of tensor type.
+        const auto* tuple_type = param->type_as<TupleTypeNode>();
+        for (Type field : tuple_type->fields) {
+          const auto* ttype = field.as<TensorTypeNode>();
+          // TODO(@icemelon): Allow recursive tuple
+          ICHECK(ttype != nullptr);
+          tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+          fn_inputs.push_back(tensor);
+          inputs.push_back(tensor);
+        }
+      }
+      memo_[param] = inputs;
+    }
+    readable_name_stream_ << "fused";
+    auto outputs = this->VisitExpr(prim_func->body);
+    auto candidate_name = readable_name_stream_.str();
+    constexpr static size_t kMaxFuncNameLength = 80;
+    if (candidate_name.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+      candidate_name = truncated_name.str();
+    }
+
+    ICHECK(anchor_op_.defined());
+    // Fusion over tupled results may leave identity relationships
+    // between inputs and outputs, and those should not be scheduled.
+    // Hence schedule only non PlaceholderOp outputs.
+    tvm::Array<te::Tensor> tensor_outs;
+    for (const auto& tensor : outputs) {
+      if (!tensor->op.as<te::PlaceholderOpNode>()) {
+        tensor_outs.push_back(tensor);
+      }
+    }
+
+    te::Schedule schedule;
+    // No need to register schedule for device copy op.
+    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
+      if (use_auto_scheduler_) {
+        const auto* fauto_schedule =
+            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+        ICHECK(fauto_schedule != nullptr)
+            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
+        ObjectRef obj = (*fauto_schedule)(tensor_outs);
+        if (obj.defined()) {
+          schedule = Downcast<te::Schedule>(obj);
+        }
+      }
+
+      // Use TOPI schedule if user specified, or the function has no auto_scheduler schedule.
+      if (!schedule.defined()) {
+        ICHECK(anchor_implementation_.defined());
+        schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
+      }
+      for (const auto& scalar : scalars_) {
+        if (schedule->Contain(scalar)) {
+          schedule[scalar].compute_inline();
+        }
+      }
+    }
+
+    auto prim_fn_var = GlobalVar(candidate_name);
+    prim_fn_var->checked_type_ = prim_func->checked_type();
+    return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {});
+  }
+
+  Array<te::Tensor> VisitExpr_(const VarNode* op) final {
+    LOG(FATAL) << "Free variable " << op->name_hint();
+    return {};
+  }
+
+  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
+    using tir::make_const;
+    ICHECK(op->is_scalar());
+    void* data = op->data->data;
+    DataType dtype = DataType(op->data->dtype);
+    auto value = te::compute(
+        {},
+        [&](const Array<tvm::tir::Var>&) {
+          if (dtype == DataType::Int(32)) {
+            return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+          } else if (dtype == DataType::Int(64)) {
+            return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+          } else if (dtype == DataType::Float(32)) {
+            return make_const(dtype, static_cast<const float*>(data)[0]);
+          } else if (dtype == DataType::Float(64)) {
+            return make_const(dtype, static_cast<const double*>(data)[0]);
+          } else if (dtype == DataType::Bool()) {
+            return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+          } else {
+            LOG(FATAL) << "not handled";
+            return tvm::PrimExpr();
+          }
+        },
+        "compile_engine_const", topi::kBroadcast);
+    scalars_.push_back(value->op);
+    return {value};
+  }
+
+  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
+    ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
+
+    Array<te::Tensor> inputs;
+    int count_tuple = 0;
+    for (Expr arg : call_node->args) {
+      if (arg->checked_type().as<TupleTypeNode>()) {
+        ++count_tuple;
+      }
+      for (te::Tensor tensor : VisitExpr(arg)) {
+        inputs.push_back(tensor);
+      }
+    }
+    if (count_tuple) {
+      ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
+    }
+
+    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
+    Op op = Downcast<Op>(call_node->op);
+
+    Array<te::Tensor> outputs;
+    OpImplementation impl;
+    // Skip fcompute for device copy operators as it is not registered.
+    if (op == device_copy_op_) {
+      const auto* copy_input = inputs[0].operator->();
+      outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
+    } else {
+      LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
+      outputs = lowered_out->outputs;
+      impl = lowered_out->implementation;
+    }
+
+    int op_pattern = fpattern[op];
+    if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+      ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+          << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
+          << " anchor=" << anchor_op_ << " current=" << op;
+    }
+    if (op_pattern >= anchor_op_pattern_) {
+      anchor_op_ = op;
+      anchor_attrs_ = call_node->attrs;
+      anchor_op_pattern_ = op_pattern;
+      anchor_implementation_ = impl;
+    }
+    if (outputs.size() != 1) {
+      const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type) << "Expect output to be a tuple type";
+      ICHECK_EQ(tuple_type->fields.size(), outputs.size());
+    }
+    // Set the name to `__copy`. It will be detected in graph runtime to perform
+    // data copy across devices.
+    if (op == device_copy_op_) {
+      readable_name_stream_.str(std::string());
+      readable_name_stream_ << "__copy";
+    } else {
+      readable_name_stream_ << '_' << op->name;
+    }
+    return outputs;
+  }
+
+  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
+    LOG(FATAL) << "Do not support sub function";

Review comment:
       Can you add what is supported?
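For readers following the `ScheduleGetter` diff above, the anchor-op bookkeeping in `VisitExpr_(const CallNode*)` can be sketched without TVM as below. This is a standalone illustration under stated assumptions, not TVM code: `OpInfo` and `SelectAnchor` are hypothetical names, each op is reduced to a name and a pattern, and the enum assumes only the relative ordering of TVM's `OpPatternKind` values.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Assumed subset of TVM's OpPatternKind, ordered from simplest to most complex.
enum PatternKind {
  kElemWise = 0,
  kBroadcast = 1,
  kInjective = 2,
  kCommReduce = 3,
  kOutEWiseFusable = 4,
};

// Hypothetical stand-in for one primitive call visited inside a fused function.
struct OpInfo {
  std::string name;
  int pattern;
};

// Returns the anchor op: the op with the highest pattern, with ties going to the
// later op because the comparison is >=. Without auto_scheduler, a second op of
// pattern >= kCommReduce is rejected, like the ICHECK in VisitExpr_(const CallNode*).
std::string SelectAnchor(const std::vector<OpInfo>& ops, bool use_auto_scheduler) {
  std::string anchor;  // an empty name plays the role of an undefined anchor_op_
  int anchor_pattern = 0;
  for (const OpInfo& op : ops) {
    assert(op.pattern >= 0);  // pattern kinds are non-negative
    if (!use_auto_scheduler && op.pattern >= kCommReduce && !anchor.empty() &&
        anchor_pattern >= kCommReduce) {
      throw std::runtime_error("two complicated ops: anchor=" + anchor + " current=" + op.name);
    }
    if (op.pattern >= anchor_pattern) {
      anchor = op.name;
      anchor_pattern = op.pattern;
    }
  }
  return anchor;
}
```

With `{nn.conv2d, add}` the anchor stays `nn.conv2d`, while two elementwise ops anchor on the later one because of the `>=` comparison. With auto_scheduler enabled, two complex ops are accepted and the later one becomes the anchor.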

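The function-name logic in `Create` (a `fused` prefix, one `_<op>` per visited call, a reset to `__copy` for `device_copy`, and truncation to 80 characters plus a hash suffix) can be approximated as follows. `BuildReadableName` and `TruncateName` are hypothetical helpers written for this sketch; `std::hash` output is implementation-defined, so only the shape of a truncated name is predictable, not its exact digits.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper mirroring how ScheduleGetter fills readable_name_stream_:
// start from "fused" and append "_<op name>" per visited call; a device_copy
// call clears what was written so far and writes "__copy" instead.
std::string BuildReadableName(const std::vector<std::string>& op_names) {
  std::ostringstream name;
  name << "fused";
  for (const std::string& op : op_names) {
    if (op == "device_copy") {
      name.str(std::string());  // discard the name built so far
      name << "__copy";
    } else {
      name << '_' << op;
    }
  }
  return name.str();
}

// Names longer than kMaxFuncNameLength are cut to that length and suffixed with
// "_<hash of the full name>_", so two long names with the same 80-character prefix
// still differ unless their hashes collide.
std::string TruncateName(const std::string& candidate) {
  constexpr std::size_t kMaxFuncNameLength = 80;
  if (candidate.size() <= kMaxFuncNameLength) {
    return candidate;
  }
  std::ostringstream truncated;
  truncated << candidate.substr(0, kMaxFuncNameLength) << "_"
            << std::hash<std::string>{}(candidate) << "_";
  std::string result = truncated.str();
  assert(result.size() > kMaxFuncNameLength);  // prefix plus "_<digits>_"
  return result;
}
```

Hashing the full candidate (rather than only the cut-off tail) means any difference anywhere in a long name changes the suffix.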
##########
File path: tests/python/relay/test_backend_graph_runtime.py
##########
@@ -231,10 +232,12 @@ def test_graph_executor_nested_tuples():
 
 
 if __name__ == "__main__":
-    test_plan_memory()
-    test_with_params()
-    test_add_op_scalar()
     test_add_op_tensor()
-    test_add_op_broadcast()
-    test_gru_like()
-    test_compile_nested_tuples()
+    # test_plan_memory()
+    # test_with_params()
+    # test_add_op_scalar()
+    # test_add_op_tensor()
+    # test_add_op_broadcast()
+    # test_gru_like()
+    # test_compile_nested_tuples()
+    # test_add_op_tensor()

Review comment:
       uncomment?

##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
+  auto n = make_object<LoweredOutputNode>();
+  n->outputs = std::move(outputs);
+  n->implementation = std::move(impl);
+  data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+  auto n = make_object<CCacheKeyNode>();
+  n->source_func = std::move(source_func);
+  n->target = std::move(target);
+  data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array<te::Tensor> inputs,
+                       tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+                       tvm::Array<Integer> shape_func_param_states, IRModule funcs) {
+  auto n = make_object<CachedFuncNode>();
+  n->target = target;
+  n->prim_fn_var = prim_fn_var;
+  n->inputs = inputs;
+  n->outputs = outputs;
+  n->schedule = schedule;
+  n->shape_func_param_states = shape_func_param_states;
+  n->funcs = funcs;
+  data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+  // for now, we always use int32 shape when possible
+  // even if the result of shape inference becomes int64.
+  Array<IndexExpr> res;
+  for (IndexExpr val : shape) {
+    const int64_t* pval = tir::as_const_int(val);
+    if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+      ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+      ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+      res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+      res.push_back(val);
+#endif  // TVM_INDEX_DEFAULT_I64
+    } else if (val->IsInstance<tir::AnyNode>()) {
+      res.push_back(val.as<tir::AnyNode>()->ToVar());
+    } else {
+      res.push_back(val);
+    }
+  }
+  return res;
+}
+
+// The getter to get schedule from compile engine.
+// Get schedule from functor.
+class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+  explicit ScheduleGetter(Target target)
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {
+    // Whether to use auto_scheduler schedule.
+    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+  }
+
+  CachedFunc Create(const Function& prim_func, std::function<std::string(std::string)> renamer) {
+    Array<tvm::te::Tensor> fn_inputs;
+    for (Var param : prim_func->params) {
+      Array<tvm::te::Tensor> inputs;
+      if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+        tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+        fn_inputs.push_back(tensor);
+        inputs.push_back(tensor);
+      } else {
+        // flatten tuple of tensor type.
+        const auto* tuple_type = param->type_as<TupleTypeNode>();
+        for (Type field : tuple_type->fields) {
+          const auto* ttype = field.as<TensorTypeNode>();
+          // TODO(@icemelon): Allow recursive tuple
+          ICHECK(ttype != nullptr);
+          tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+          fn_inputs.push_back(tensor);
+          inputs.push_back(tensor);
+        }
+      }
+      memo_[param] = inputs;
+    }
+    readable_name_stream_ << "fused";
+    auto outputs = this->VisitExpr(prim_func->body);
+    auto candidate_name = readable_name_stream_.str();
+    constexpr static size_t kMaxFuncNameLength = 80;
+    if (candidate_name.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+      candidate_name = truncated_name.str();
+    }
+
+    ICHECK(anchor_op_.defined());
+    // Fusion over tupled results may leave identity relationships
+    // between inputs and outputs, and those should not be scheduled.
+    // Hence schedule only non PlaceholderOp outputs.
+    tvm::Array<te::Tensor> tensor_outs;
+    for (const auto& tensor : outputs) {
+      if (!tensor->op.as<te::PlaceholderOpNode>()) {
+        tensor_outs.push_back(tensor);
+      }
+    }
+
+    te::Schedule schedule;
+    // No need to register schedule for device copy op.
+    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
+      if (use_auto_scheduler_) {
+        const auto* fauto_schedule =
+            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+        ICHECK(fauto_schedule != nullptr)
+            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
+        ObjectRef obj = (*fauto_schedule)(tensor_outs);
+        if (obj.defined()) {
+          schedule = Downcast<te::Schedule>(obj);
+        }
+      }
+
+      // Use TOPI schedule if user specified, or the function has no auto_scheduler schedule.
+      if (!schedule.defined()) {
+        ICHECK(anchor_implementation_.defined());
+        schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
+      }
+      for (const auto& scalar : scalars_) {
+        if (schedule->Contain(scalar)) {
+          schedule[scalar].compute_inline();
+        }
+      }
+    }
+
+    auto prim_fn_var = GlobalVar(candidate_name);
+    prim_fn_var->checked_type_ = prim_func->checked_type();
+    return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {});
+  }
+
+  Array<te::Tensor> VisitExpr_(const VarNode* op) final {
+    LOG(FATAL) << "Free variable " << op->name_hint();
+    return {};
+  }
+
+  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
+    using tir::make_const;
+    ICHECK(op->is_scalar());
+    void* data = op->data->data;
+    DataType dtype = DataType(op->data->dtype);
+    auto value = te::compute(
+        {},
+        [&](const Array<tvm::tir::Var>&) {
+          if (dtype == DataType::Int(32)) {
+            return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+          } else if (dtype == DataType::Int(64)) {
+            return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+          } else if (dtype == DataType::Float(32)) {
+            return make_const(dtype, static_cast<const float*>(data)[0]);
+          } else if (dtype == DataType::Float(64)) {
+            return make_const(dtype, static_cast<const double*>(data)[0]);
+          } else if (dtype == DataType::Bool()) {
+            return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+          } else {
+            LOG(FATAL) << "not handled";
+            return tvm::PrimExpr();
+          }
+        },
+        "compile_engine_const", topi::kBroadcast);
+    scalars_.push_back(value->op);
+    return {value};
+  }
+
+  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
+    ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
+
+    Array<te::Tensor> inputs;
+    int count_tuple = 0;
+    for (Expr arg : call_node->args) {
+      if (arg->checked_type().as<TupleTypeNode>()) {
+        ++count_tuple;
+      }
+      for (te::Tensor tensor : VisitExpr(arg)) {
+        inputs.push_back(tensor);
+      }
+    }
+    if (count_tuple) {
+      ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
+    }
+
+    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
+    Op op = Downcast<Op>(call_node->op);
+
+    Array<te::Tensor> outputs;
+    OpImplementation impl;
+    // Skip fcompute for device copy operators as it is not registered.
+    if (op == device_copy_op_) {
+      const auto* copy_input = inputs[0].operator->();
+      outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
+    } else {
+      LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
+      outputs = lowered_out->outputs;
+      impl = lowered_out->implementation;
+    }
+
+    int op_pattern = fpattern[op];
+    if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+      ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+          << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
+          << " anchor=" << anchor_op_ << " current=" << op;
+    }
+    if (op_pattern >= anchor_op_pattern_) {
+      anchor_op_ = op;
+      anchor_attrs_ = call_node->attrs;
+      anchor_op_pattern_ = op_pattern;
+      anchor_implementation_ = impl;
+    }
+    if (outputs.size() != 1) {
+      const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type) << "Expect output to be a tuple type";
+      ICHECK_EQ(tuple_type->fields.size(), outputs.size());
+    }
+    // Set the name to `__copy`. It will be detected in graph runtime to perform
+    // data copy across devices.
+    if (op == device_copy_op_) {
+      readable_name_stream_.str(std::string());
+      readable_name_stream_ << "__copy";
+    } else {
+      readable_name_stream_ << '_' << op->name;
+    }
+    return outputs;
+  }
+
+  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
+    LOG(FATAL) << "Do not support sub function";
+    return Array<te::Tensor>();
+  }
+
+  Array<te::Tensor> VisitExpr_(const LetNode* op) final {
+    Array<te::Tensor> val = VisitExpr(op->value);
+    ICHECK(!memo_.count(op->var));
+    memo_[op->var] = val;
+    return VisitExpr(op->body);
+  }
+
+  Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
+    Array<te::Tensor> fields;
+    for (Expr field : op->fields) {
+      ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
+      Array<te::Tensor> res = VisitExpr(field);
+      ICHECK_EQ(res.size(), 1);
+      fields.push_back(res[0]);
+    }
+    return fields;
+  }
+
+  Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
+    const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
+    Array<te::Tensor> tuple = VisitExpr(op->tuple);
+    ICHECK_EQ(tuple_type->fields.size(), tuple.size());
+    ICHECK_GE(op->index, 0);
+    ICHECK_LT(static_cast<size_t>(op->index), tuple.size());
+    return {tuple[op->index]};
+  }
+
+ private:
+  tvm::Target target_;
+  Op anchor_op_;
+  Attrs anchor_attrs_;
+  int anchor_op_pattern_{0};
+  OpImplementation anchor_implementation_;
+  std::ostringstream readable_name_stream_;
+  Array<te::Operation> scalars_;
+  bool use_auto_scheduler_;
+  // Cache device copy op for equivalence checking to reduce registry lookup
+  // overhead for each invocation of call node when retrieving schedules.
+  const Op& device_copy_op_;
+};
+
+/*!
+ * \brief Create schedule for target.
+ * \param source_func The primitive function to be lowered.
+ * \param target The target we want to create schedule for.
+ * \return Pair of schedule and cache.
+ *  The funcs field in cache is not yet populated.
+ */
+CachedFunc PrimFuncFor(const Function& source_func, const Target& target,
+                       std::function<std::string(std::string)> renamer) {
+  return ScheduleGetter(target).Create(source_func, renamer);
+}
+
+// Creates shape function from functor.
+class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+  MakeShapeFunc() {}
+
+  CachedFunc Create(const Function& prim_func, const Target& target,
+                    std::function<std::string(std::string)> renamer) {
+    Array<te::Tensor> inputs;
+    TShapeDataDependent shape_func_param_states;
+
+    for (auto param : prim_func->params) {
+      param_states_[param] = kNoNeed;
+      Array<tvm::te::Tensor> data_inputs;
+      Array<tvm::te::Tensor> shape_inputs;
+
+      auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) {
+        // Add data placeholder
+        Shape shape = GetShape(ttype->shape);
+        tvm::te::Tensor data_tensor = tvm::te::placeholder(shape, ttype->dtype);
+        data_inputs.push_back(data_tensor);
+        // Add shape placeholder
+        int64_t ndim = shape.size();
+        Shape sshape;
+        if (ndim > 0) {
+          sshape.push_back(tvm::Integer(ndim));
+        }
+        tvm::te::Tensor shape_tensor = tvm::te::placeholder(sshape, DataType::Int(64));
+        shape_inputs.push_back(shape_tensor);
+      };
+
+      if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+        add_placeholder(ttype);
+      } else {
+        // flatten tuple of tensor type.
+        const auto* tuple_type = param->type_as<TupleTypeNode>();
+        // TODO(@icemelon): Support recursive tuple
+        ICHECK(tuple_type);
+        for (Type field : tuple_type->fields) {
+          const auto* ttype = field.as<TensorTypeNode>();
+          ICHECK(ttype);
+          add_placeholder(ttype);
+        }
+      }
+      param_data_[param] = data_inputs;
+      param_shapes_[param] = shape_inputs;
+    }
+
+    // Setup the name;
+    readable_name_stream_ << "shape_func";
+
+    // Create the `te::Tensor`s which represent the output.
+    auto outputs = VisitExpr(prim_func->body);
+
+    // Generate a name.
+    auto candidate_name = readable_name_stream_.str();
+    constexpr static size_t kMaxFuncNameLength = 80;
+    if (candidate_name.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+      candidate_name = truncated_name.str();
+    }
+
+    auto func_name = renamer(candidate_name);
+
+    // Set all the inputs correctly.
+    for (auto param : prim_func->params) {
+      int state = param_states_[param];
+      shape_func_param_states.push_back(IntImm(DataType::Int(32), state));
+      if (state & kNeedInputData) {
+        for (auto t : param_data_[param]) {
+          inputs.push_back(t);
+        }
+      }
+      if (state & kNeedInputShape) {
+        for (auto t : param_shapes_[param]) {
+          inputs.push_back(t);
+        }
+      }
+    }
+
+    auto prim_fn_gvar = GlobalVar(func_name);
+    prim_fn_gvar->checked_type_ = prim_func->checked_type();
+
+    // generate schedule for shape func
+    Array<te::Operation> out_ops;
+    for (auto t : outputs) {
+      out_ops.push_back(t->op);
+    }
+    auto schedule = te::create_schedule(out_ops);
+    tvm::te::AutoInlineInjective(schedule);
+    for (const auto& scalar : scalars_) {
+      auto scalar_op = scalar->op;
+      if (schedule->Contain(scalar_op)) {
+        schedule[scalar_op].compute_inline();
+      }
+    }
+
+    Array<te::Tensor> all_args = Array<te::Tensor>(inputs);
+    for (te::Tensor arg : outputs) {
+      all_args.push_back(arg);
+    }
+
+    using tvm::transform::PassContext;
+    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+
+    std::unordered_map<te::Tensor, tir::Buffer> binds;
+    auto ir_module = tvm::lower(schedule, all_args, func_name, binds);
+
+    return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, shape_func_param_states,
+                      ir_module);
+  }
+
+  Array<te::Tensor> VisitExpr(const Expr& expr) final {
+    if (expr.as<VarNode>()) {
+      // Do not memoize vars because shape functions could use either the data
+      // or the shape of a var each time.
+      return ExprFunctor::VisitExpr(expr);
+    }
+    // For other case, do memoized visit
+    return backend::MemoizedExprTranslator<Array<te::Tensor>>::VisitExpr(expr);
+  }
+
+  Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
+    auto var = GetRef<Var>(var_node);
+    auto it = param_states_.find(var);
+    if (it == param_states_.end()) {
+      LOG(FATAL) << "Free variable " << var->name_hint();

Review comment:
       ```suggestion
         LOG(FATAL) << "Unexpected free variable " << var->name_hint();
   ```
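The default branch of `GetShape` (when `TVM_INDEX_DEFAULT_I64` is not defined) narrows constant extents to int32 after range-checking them. A minimal sketch of that rule follows, assuming hypothetical `Dim`/`Int32Dim` types in place of `IndexExpr`; it throws where the original uses `ICHECK`, and it simply passes symbolic extents through instead of modelling the `tir::AnyNode` to `Var` conversion.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for an IndexExpr: either a constant extent or a symbol.
struct Dim {
  bool is_const;
  std::int64_t value;  // meaningful only when is_const is true
  std::string symbol;  // meaningful only when is_const is false
};

// The same dimension after narrowing the constant extent to int32.
struct Int32Dim {
  bool is_const;
  std::int32_t value;
  std::string symbol;
};

// Mirrors the default (non-TVM_INDEX_DEFAULT_I64) branch of GetShape: constant
// extents must fit in int32 and are then narrowed; symbolic extents pass through.
std::vector<Int32Dim> NarrowToInt32(const std::vector<Dim>& shape) {
  std::vector<Int32Dim> res;
  for (const Dim& d : shape) {
    if (!d.is_const) {
      res.push_back({false, 0, d.symbol});
      continue;
    }
    if (d.value > std::numeric_limits<std::int32_t>::max() ||
        d.value < std::numeric_limits<std::int32_t>::min()) {
      throw std::out_of_range("extent does not fit in int32: " + std::to_string(d.value));
    }
    res.push_back({true, static_cast<std::int32_t>(d.value), ""});
  }
  assert(res.size() == shape.size());  // one output dimension per input dimension
  return res;
}
```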

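In `MakeShapeFunc::Create`, each parameter's state decides whether its data placeholders, its shape placeholders, or both become inputs of the shape function. The sketch below illustrates that selection with strings standing in for `te::Tensor`s; `ParamInputs` and `SelectShapeFuncInputs` are hypothetical, and the enum values are an assumption consistent with the bitmask tests (`state & kNeedInputData`, `state & kNeedInputShape`) in the diff.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Assumed values, consistent with the bitmask tests in MakeShapeFunc::Create.
enum ShapeFuncParamState {
  kNoNeed = 0,
  kNeedInputData = 1,
  kNeedInputShape = 2,
};

// Hypothetical per-parameter record: the data and shape placeholders created for
// it (one of each per tensor, several when the parameter is a flattened tuple).
struct ParamInputs {
  int state;
  std::vector<std::string> data;
  std::vector<std::string> shapes;
};

// Builds the shape function's inputs in parameter order: data placeholders first
// if kNeedInputData is set, then shape placeholders if kNeedInputShape is set.
std::vector<std::string> SelectShapeFuncInputs(const std::vector<ParamInputs>& params) {
  std::vector<std::string> inputs;
  for (const ParamInputs& p : params) {
    assert(p.data.size() == p.shapes.size());  // one shape placeholder per data placeholder
    if (p.state & kNeedInputData) {
      inputs.insert(inputs.end(), p.data.begin(), p.data.end());
    }
    if (p.state & kNeedInputShape) {
      inputs.insert(inputs.end(), p.shapes.begin(), p.shapes.end());
    }
  }
  return inputs;
}
```

A parameter whose state is still `kNoNeed` contributes nothing; that is also the state `Create` assigns to every parameter before visiting the body.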
##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
+  auto n = make_object<LoweredOutputNode>();
+  n->outputs = std::move(outputs);
+  n->implementation = std::move(impl);
+  data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+  auto n = make_object<CCacheKeyNode>();
+  n->source_func = std::move(source_func);
+  n->target = std::move(target);
+  data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array<te::Tensor> inputs,
+                       tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+                       tvm::Array<Integer> shape_func_param_states, IRModule funcs) {
+  auto n = make_object<CachedFuncNode>();
+  n->target = target;
+  n->prim_fn_var = prim_fn_var;
+  n->inputs = inputs;
+  n->outputs = outputs;
+  n->schedule = schedule;
+  n->shape_func_param_states = shape_func_param_states;
+  n->funcs = funcs;
+  data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+  // for now, we always use int32 shape when possible
+  // even if the result of shape inference becomes int64.
+  Array<IndexExpr> res;
+  for (IndexExpr val : shape) {
+    const int64_t* pval = tir::as_const_int(val);
+    if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+      ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+      ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+      res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+      res.push_back(val);
+#endif  // TVM_INDEX_DEFAULT_I64
+    } else if (val->IsInstance<tir::AnyNode>()) {
+      res.push_back(val.as<tir::AnyNode>()->ToVar());
+    } else {
+      res.push_back(val);
+    }
+  }
+  return res;
+}
+
+// The getter to get schedule from compile engine.
+// Get schedule from functor.
+class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
+ public:
+  explicit ScheduleGetter(Target target)
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {
+    // Whether to use auto_scheduler schedule.
+    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+  }
+
+  CachedFunc Create(const Function& prim_func, std::function<std::string(std::string)> renamer) {
+    Array<tvm::te::Tensor> fn_inputs;
+    for (Var param : prim_func->params) {
+      Array<tvm::te::Tensor> inputs;
+      if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
+        tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+        fn_inputs.push_back(tensor);
+        inputs.push_back(tensor);
+      } else {
+        // flatten tuple of tensor type.
+        const auto* tuple_type = param->type_as<TupleTypeNode>();
+        for (Type field : tuple_type->fields) {
+          const auto* ttype = field.as<TensorTypeNode>();
+          // TODO(@icemelon): Allow recursive tuple
+          ICHECK(ttype != nullptr);
+          tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+          fn_inputs.push_back(tensor);
+          inputs.push_back(tensor);
+        }
+      }
+      memo_[param] = inputs;
+    }
+    readable_name_stream_ << "fused";
+    auto outputs = this->VisitExpr(prim_func->body);
+    auto candidate_name = readable_name_stream_.str();
+    constexpr static size_t kMaxFuncNameLength = 80;
+    if (candidate_name.size() > kMaxFuncNameLength) {
+      std::stringstream truncated_name;
+      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+      truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+      candidate_name = truncated_name.str();
+    }
+
+    ICHECK(anchor_op_.defined());
+    // Fusion over tupled results may leave identity relationships
+    // between inputs and outputs, and those should not be scheduled.
+    // Hence schedule only non PlaceholderOp outputs.
+    tvm::Array<te::Tensor> tensor_outs;
+    for (const auto& tensor : outputs) {
+      if (!tensor->op.as<te::PlaceholderOpNode>()) {
+        tensor_outs.push_back(tensor);
+      }
+    }
+
+    te::Schedule schedule;
+    // No need to register schedule for device copy op.
+    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
+      if (use_auto_scheduler_) {
+        const auto* fauto_schedule =
+            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
+        ICHECK(fauto_schedule != nullptr)
+            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
+        ObjectRef obj = (*fauto_schedule)(tensor_outs);
+        if (obj.defined()) {
+          schedule = Downcast<te::Schedule>(obj);
+        }
+      }
+
+      // Use TOPI schedule if user specified, or the function has no auto_scheduler schedule.
+      if (!schedule.defined()) {
+        ICHECK(anchor_implementation_.defined());
+        schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
+      }
+      for (const auto& scalar : scalars_) {
+        if (schedule->Contain(scalar)) {
+          schedule[scalar].compute_inline();
+        }
+      }
+    }
+
+    auto prim_fn_var = GlobalVar(candidate_name);
+    prim_fn_var->checked_type_ = prim_func->checked_type();
+    return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, {});
+  }
+
+  Array<te::Tensor> VisitExpr_(const VarNode* op) final {
+    LOG(FATAL) << "Free variable " << op->name_hint();
+    return {};
+  }
+
+  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
+    using tir::make_const;
+    ICHECK(op->is_scalar());
+    void* data = op->data->data;
+    DataType dtype = DataType(op->data->dtype);
+    auto value = te::compute(
+        {},
+        [&](const Array<tvm::tir::Var>&) {
+          if (dtype == DataType::Int(32)) {
+            return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+          } else if (dtype == DataType::Int(64)) {
+            return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+          } else if (dtype == DataType::Float(32)) {
+            return make_const(dtype, static_cast<const float*>(data)[0]);
+          } else if (dtype == DataType::Float(64)) {
+            return make_const(dtype, static_cast<const double*>(data)[0]);
+          } else if (dtype == DataType::Bool()) {
+            return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+          } else {
+            LOG(FATAL) << "not handled";
+            return tvm::PrimExpr();
+          }
+        },
+        "compile_engine_const", topi::kBroadcast);
+    scalars_.push_back(value->op);
+    return {value};
+  }
+
+  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
+    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
+    ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
+
+    Array<te::Tensor> inputs;
+    int count_tuple = 0;
+    for (Expr arg : call_node->args) {
+      if (arg->checked_type().as<TupleTypeNode>()) {
+        ++count_tuple;
+      }
+      for (te::Tensor tensor : VisitExpr(arg)) {
+        inputs.push_back(tensor);
+      }
+    }
+    if (count_tuple) {
+      ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
+    }
+
+    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
+    Op op = Downcast<Op>(call_node->op);
+
+    Array<te::Tensor> outputs;
+    OpImplementation impl;
+    // Skip fcompute for device copy operators as it is not registered.
+    if (op == device_copy_op_) {
+      const auto* copy_input = inputs[0].operator->();
+      outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
+    } else {
+      LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
+      outputs = lowered_out->outputs;
+      impl = lowered_out->implementation;
+    }
+
+    int op_pattern = fpattern[op];
+    if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+      ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
+          << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
+          << " anchor=" << anchor_op_ << " current=" << op;
+    }
+    if (op_pattern >= anchor_op_pattern_) {
+      anchor_op_ = op;
+      anchor_attrs_ = call_node->attrs;
+      anchor_op_pattern_ = op_pattern;
+      anchor_implementation_ = impl;
+    }
+    if (outputs.size() != 1) {
+      const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type) << "Expect output to be a tuple type";
+      ICHECK_EQ(tuple_type->fields.size(), outputs.size());
+    }
+    // Set the name to `__copy`. It will be detected in graph runtime to perform
+    // data copy across devices.
+    if (op == device_copy_op_) {
+      readable_name_stream_.str(std::string());
+      readable_name_stream_ << "__copy";
+    } else {
+      readable_name_stream_ << '_' << op->name;
+    }
+    return outputs;
+  }
+
+  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
+    LOG(FATAL) << "Do not support sub function";
+    return Array<te::Tensor>();
+  }
+
+  Array<te::Tensor> VisitExpr_(const LetNode* op) final {
+    Array<te::Tensor> val = VisitExpr(op->value);
+    ICHECK(!memo_.count(op->var));
+    memo_[op->var] = val;
+    return VisitExpr(op->body);
+  }
+
+  Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
+    Array<te::Tensor> fields;
+    for (Expr field : op->fields) {
+      ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";

Review comment:
       ```suggestion
         ICHECK(field->checked_type().as<TensorTypeNode>()) << "Expected a Tuple of Tensor, but got " << field->checked_type()->GetTypeName();
   ```
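The anchor-op selection in the quoted `VisitExpr_(const CallNode*)` is easy to misread, so here is a minimal standalone sketch of the rule as it appears in the diff. Each op becomes the anchor when its pattern is at least the current anchor's pattern. Without the auto-scheduler, an op at or above `kCommReduce` is rejected if the anchor is already at that level. `OpInfo`, `SelectAnchor`, and the `PatternKind` enum are hypothetical stand-ins: the enum copies the ordering of TVM's `OpPatternKind` (assumed here), and an exception replaces `ICHECK`.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical subset of op pattern kinds. The values copy the ordering of
// TVM's OpPatternKind (assumed here): a larger value means a more complex op.
enum PatternKind {
  kElemWise = 0,
  kBroadcast = 1,
  kInjective = 2,
  kCommReduce = 3,
  kOutEWiseFusable = 4,
};

// Hypothetical record for one primitive op call inside a fused function.
struct OpInfo {
  std::string name;
  PatternKind pattern;
};

// Returns the anchor op: the last op whose pattern is >= the current anchor's.
// Without the auto-scheduler, an op with pattern >= kCommReduce is an error
// once the anchor is already at that level (the ICHECK in the diff).
std::string SelectAnchor(const std::vector<OpInfo>& ops, bool use_auto_scheduler) {
  std::string anchor;
  int anchor_pattern = 0;
  bool has_anchor = false;
  for (const OpInfo& op : ops) {
    if (!use_auto_scheduler && op.pattern >= kCommReduce && has_anchor &&
        anchor_pattern >= kCommReduce) {
      throw std::runtime_error("Cannot apply TOPI schedule to a primitive function "
                               "with two complicated ops: " + anchor + ", " + op.name);
    }
    if (op.pattern >= anchor_pattern) {
      anchor = op.name;
      anchor_pattern = op.pattern;
      has_anchor = true;
    }
  }
  return anchor;
}
```

In this sketch, ops are visited in the order given, which stands in for the post-order traversal of the fused function body.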

##########
File path: src/relay/ir/function.cc
##########
@@ -62,9 +62,12 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
-      auto* node = static_cast<const FunctionNode*>(ref.get());
-      p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << node->body
-                << ", " << node->type_params << ", " << node->attrs << ")";
+      // auto* node = static_cast<const FunctionNode*>(ref.get());
+      // p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " <<
+      // node->body
+      //           << ", " << node->type_params << ", " << node->attrs << ")";

Review comment:
       remove?






[GitHub] [tvm] csullivan commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
csullivan commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r608872546



##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -349,65 +380,47 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
     return AddNode(node, GetRef<Expr>(op));
   }
 
-  std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
-    Expr expr = GetRef<Expr>(op);
-    Function func;
-    if (op->op.as<OpNode>()) {
-      LOG(FATAL) << "Operators should be transformed away; try applying"
-                 << "the fuse_ops transformation to the expression.";
-    } else if (op->op.as<GlobalVarNode>()) {
-      LOG(FATAL) << "Not implemented";
-    } else if (op->op.as<FunctionNode>()) {
-      func = GetRef<Function>(op->op.as<FunctionNode>());
-    } else {
-      LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
-    }
-    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
-      LOG(FATAL) << "TVM only support calls to primitive functions "
-                 << "(i.e functions composed of fusable operator invocations)";
-    }
+  std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
+    relay::Call call = GetRef<Call>(call_node);
+    if (auto global_node = call->op.as<GlobalVarNode>()) {
 
-    auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
-    auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
-    Target target;
-    // Handle external function
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      target = Target("ext_dev");
-      CCacheKey key = (*pf0)(func, target);
-      CachedFunc ext_func = (*pf1)(compile_engine_, key);
-      ICHECK(ext_func.defined()) << "External function is not defined.";
-      UpdateConstants(func, &params_);
-      return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
-    }
+      auto prim_fn_name = global_node->name_hint;
 
-    ICHECK_GE(storage_device_map_.count(expr), 0);
-    auto& device_type = storage_device_map_[expr][1];
-    auto call_dev_type = device_type[0]->value;
-    // Normal Relay Function
-    if (targets_.size() == 1) {
-      // homogeneous execution.
-      const auto& it = targets_.begin();
-      target = (*it).second;
-    } else {
-      // heterogeneous execution.
-      std::string call_dev_name;
-      if (call_dev_type == 0) {
-        call_dev_name = "llvm";
+      Target target;
+
+      // // Handle external function
+      // if (func->GetAttr<String>(attr::kCompiler).defined()) {
+      //   UpdateConstants(func, &params_);
+      //   return GraphAddCallNode(call_node, prim_fn_name, prim_fn_name);

Review comment:
      We still need to emit a call for the external function, no?

##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -33,10 +33,15 @@
 #include <vector>
 
 #include "compile_engine.h"
+#include "te_compiler.h"
 #include "utils.h"
 
 namespace tvm {
 namespace relay {
+
+/// TODO(@jroesch, @chris): declare directly elsewhere

Review comment:
       ```suggestion
   /// TODO(@jroesch, @csullivan): declare directly elsewhere
   ```

##########
File path: tests/python/relay/test_backend_graph_executor.py
##########
@@ -126,6 +126,7 @@ def test_plan_memory():
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
     mod = relay.transform.InferType()(mod)
+    print(mod)

Review comment:
       ```suggestion
   ```

##########
File path: tests/python/relay/test_backend_graph_executor.py
##########
@@ -231,10 +232,12 @@ def test_graph_executor_nested_tuples():
 
 
 if __name__ == "__main__":
-    test_plan_memory()
-    test_with_params()
-    test_add_op_scalar()
     test_add_op_tensor()
-    test_add_op_broadcast()
-    test_gru_like()
-    test_compile_nested_tuples()
+    # test_plan_memory()
+    # test_with_params()
+    # test_add_op_scalar()
+    # test_add_op_tensor()
+    # test_add_op_broadcast()
+    # test_gru_like()
+    # test_compile_nested_tuples()
+    # test_add_op_tensor()

Review comment:
       Shouldn't need to disable these,
   ```suggestion
       test_plan_memory()
       test_with_params()
       test_add_op_scalar()
       test_add_op_tensor()
       test_add_op_broadcast()
       test_gru_like()
       test_compile_nested_tuples()
       test_add_op_tensor()
   ```

##########
File path: src/relay/backend/compile_engine.cc
##########
@@ -731,50 +170,50 @@ class CompileEngineImpl : public CompileEngineNode {
     // No need to lower external functions for now. We will invoke the external
     // codegen tool once and lower all functions together.
     if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
-      auto cache_node = make_object<CachedFuncNode>();
+      auto ir_module = IRModule();
       const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
       ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
-      cache_node->func_name = std::string(name_node.value());
-      cache_node->target = Target("ext_dev");
-      cache_node->funcs->Add(GlobalVar(cache_node->func_name), key->source_func);
-      value->cached_func = CachedFunc(cache_node);
+      auto func_name = std::string(name_node.value());
+      auto target = Target("ext_dev");
+      auto global_var = GlobalVar(func_name);
+      global_var->checked_type_ = key->source_func->checked_type();
+      ir_module->Add(global_var, key->source_func);
+      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
       return value;
     }
+
     // Enforce use the target.
     With<Target> target_scope(key->target);
 
     ICHECK(!value->cached_func.defined());
-    auto cfunc = CreateSchedule(key->source_func, key->target);
-    auto cache_node = make_object<CachedFuncNode>(*(cfunc.operator->()));
+    auto cfunc = PrimFuncFor(key->source_func, key->target,
+                             [&](std::string name) { return GetUniqueName(name, name_map_); });
 
     // Skip lowering for device copy node.
     const Expr body = (key->source_func)->body;
     if (const CallNode* call_node = body.as<CallNode>()) {
       if (call_node->attrs.as<DeviceCopyAttrs>()) {
-        value->cached_func = CachedFunc(cache_node);
+        value->cached_func = cfunc;
         return value;
       }
     }
 
-    cache_node->func_name = GetUniqueName(cache_node->func_name);
     // NOTE: array will copy on write.
-    Array<te::Tensor> all_args = cache_node->inputs;
-    for (te::Tensor arg : cache_node->outputs) {
+    Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+    for (te::Tensor arg : cfunc->outputs) {
       all_args.push_back(arg);
     }
-    // lower the function
-    if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
-      cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func);
-    } else {
-      using tvm::transform::PassContext;
-      With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
 
-      std::unordered_map<te::Tensor, tir::Buffer> binds;
-      cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds);
-    }
-    value->cached_func = CachedFunc(cache_node);
+    using tvm::transform::PassContext;
+    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+
+    std::unordered_map<te::Tensor, tir::Buffer> binds;
+    auto func_name = cfunc->prim_fn_var->name_hint;
+    cfunc->funcs->Update(tvm::lower(cfunc->schedule, all_args, func_name, binds));

Review comment:
       It looks like the compiler has been relying on `if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {` for long enough that the set of passes run in the python impl of tvm.lower (in python/driver/build_module.py) is different than the c++ `tvm::lower`. 
   
   My feeling is that we should use the same flow as before to avoid too much disruption and then refactor after this lands. 
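The concern is that dropping the registry lookup silently changes which lowering pipeline runs. Below is a minimal standalone sketch of the lookup-with-fallback shape used by the removed code: check for a registered `relay.backend.lower` and otherwise call the C++ lowering. `GlobalRegistry`, `DefaultLower`, and `LowerWithFallback` are hypothetical stand-ins, not TVM APIs, and strings replace real IR modules.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a global function registry keyed by name, such as
// "relay.backend.lower". This is not TVM's runtime::Registry API.
using LowerFn = std::function<std::string(const std::string&)>;

std::unordered_map<std::string, LowerFn>& GlobalRegistry() {
  static std::unordered_map<std::string, LowerFn> registry;
  return registry;
}

// Hypothetical stand-in for the C++ default path (tvm::lower in the diff).
std::string DefaultLower(const std::string& func_name) {
  return "cpp_lower(" + func_name + ")";
}

// Prefer an override registered under `hook` (the role the Python lowering
// plays when "relay.backend.lower" is registered); otherwise use the C++ default.
std::string LowerWithFallback(const std::string& hook, const std::string& func_name) {
  auto& registry = GlobalRegistry();
  auto it = registry.find(hook);
  if (it != registry.end()) {
    return it->second(func_name);
  }
  return DefaultLower(func_name);
}
```

Keeping this shape during the refactor means the pass list behind the registered hook still applies, and unifying it with the C++ path can happen in a follow-up.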

##########
File path: src/runtime/graph_executor/graph_executor.cc
##########
@@ -372,6 +372,7 @@ void GraphExecutor::SetupOpExecs() {
       uint32_t eid = this->entry_id(e);
       args.push_back(*(data_entry_[eid].operator->()));
     }
+

Review comment:
       ```suggestion
   ```

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    CCacheValue value = LowerInternal(key);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        ICHECK(code_gen.defined()) << "No external codegen is set";
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
+    }
+
+    // No need to cache external functions as we collected them all to create
+    // external runtime modules.
+    for (const auto& it : cached_ext_funcs) {
+      cache_.erase(it);
+    }
+    return ret;
+  }
+
+  void Clear() final { cache_.clear(); }
+
+  // List all items in the cache.
+  Array<ObjectRef> ListItems() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    Array<ObjectRef> items;
+    for (auto& kv : cache_) {
+      items.push_back(kv.first);
+      items.push_back(kv.second);
+    }
+    return items;
+  }
+
+  /*!
+   * \brief Get the cache key of the function that is being lowered currently
+   * \return the cache key
+   */
+  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+  // implement lowered func
+  CCacheValue LowerInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      if (!backend::IsCompileEngineCacheDisabled()) {
+        cache_[key] = value;
+      }
+    }
+    cur_ccache_key_ = key;
+
+    // No need to lower external functions for now. We will invoke the external
+    // codegen tool once and lower all functions together.
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+      auto ir_module = IRModule();
+      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
+      auto func_name = std::string(name_node.value());
+      auto target = Target("ext_dev");
+      auto global_var = GlobalVar(func_name);
+      global_var->checked_type_ = key->source_func->checked_type();
+      ir_module->Add(global_var, key->source_func);
+      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
+      return value;
+    }
+    // Enforce use the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+    auto cfunc = PrimFuncFor(key->source_func, key->target,
+                             [&](std::string name) { return GetUniqueName(name, name_map_); });
+
+    // Skip lowering for device copy node.
+    const Expr body = (key->source_func)->body;
+    if (const CallNode* call_node = body.as<CallNode>()) {
+      if (call_node->attrs.as<DeviceCopyAttrs>()) {
+        value->cached_func = cfunc;
+        return value;
+      }
+    }
+
+    std::cout << "Input Size: " << cfunc->inputs.size() << std::endl;
+    std::cout << "Output Size: " << cfunc->outputs.size() << std::endl;

Review comment:
       ```suggestion
   ```
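In the quoted `LowerExternalFunctions`, `cached_symbol` is only written inside the `if (cached_symbol.count(sn))` branch. That branch cannot be taken while the map is empty, so every iteration falls through to `ICHECK_NE(sn, code_gen_name)`. That check compares a symbol with a codegen name rather than with previously seen symbols. The error text ("Found duplicated symbol") suggests the intent is to reject two external functions that share a global symbol. A minimal standalone sketch of that check follows. `ExternalFunc` and `FindDuplicateSymbol` are hypothetical names, and real code would report the failure with `ICHECK` rather than a return value.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical input: (global symbol, codegen name) for each external function,
// e.g. {"dnnl_0", "dnnl"}.
using ExternalFunc = std::pair<std::string, std::string>;

// Returns the first global symbol claimed by more than one external function,
// or an empty string if all symbols are unique. The map is populated the first
// time a symbol is seen, so any later repeat is detected.
std::string FindDuplicateSymbol(const std::vector<ExternalFunc>& funcs) {
  std::unordered_map<std::string, std::string> cached_symbol;  // symbol -> codegen
  for (const ExternalFunc& f : funcs) {
    bool inserted = cached_symbol.emplace(f.first, f.second).second;
    if (!inserted) {
      return f.first;  // this symbol was already claimed by an earlier function
    }
  }
  return "";
}
```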

##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -181,23 +186,56 @@ class GraphOpNode : public GraphNode {
   const std::string op_type_name_{"tvm_op"};
 };
 
-/*! \brief Code generator for graph executor */
+/*! \brief Code generator for the graph executor, produces a module containing the graph JSON,
+ * module, and parameters.
+ */
 class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
  public:
   GraphExecutorCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
-    compile_engine_ = CompileEngine::Global();
     targets_ = targets;
   }
 
   LoweredOutput Codegen(relay::Function func) {
-    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
-    storage_device_map_ = (*pf)(func);
+    // Jared: why do we do this? just call C++ API.
+    // auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
+    // storage_device_map_ = (*pf)(func);
+    storage_device_map_ = GraphPlanMemory(func);

Review comment:
      Likely it's better to run the CollectDeviceInfo (and future CollectStorageInfo) pass prior to GraphPlanMemory and then provide the collected info to the memory planner. In that case we would not need to invoke the memory planner twice, and could instead plan memory once, after everything has been lowered.
   
   ```suggestion
       // TODO(csullivan): Consider refactoring CollectDeviceInfo out of memory
       // planner and provide as an argument for use during planning. In this way
       // we can avoid running the memory planner prior to lowering and do so only
       // thereafter.
       storage_device_map_ = GraphPlanMemory(func);
   ```

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    CCacheValue value = LowerInternal(key);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        ICHECK(code_gen.defined()) << "No external codegen is set";
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
+    }
+
+    // No need to cache external functions as we collected them all to create
+    // external runtime modules.
+    for (const auto& it : cached_ext_funcs) {
+      cache_.erase(it);
+    }
+    return ret;
+  }
+
+  void Clear() final { cache_.clear(); }
+
+  // List all items in the cache.
+  Array<ObjectRef> ListItems() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    Array<ObjectRef> items;
+    for (auto& kv : cache_) {
+      items.push_back(kv.first);
+      items.push_back(kv.second);
+    }
+    return items;
+  }
+
+  /*!
+   * \brief Get the cache key of the function that is being lowered currently
+   * \return the cache key
+   */
+  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+  // implement lowered func
+  CCacheValue LowerInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      if (!backend::IsCompileEngineCacheDisabled()) {
+        cache_[key] = value;
+      }
+    }
+    cur_ccache_key_ = key;
+
+    // No need to lower external functions for now. We will invoke the external
+    // codegen tool once and lower all functions together.
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+      auto ir_module = IRModule();
+      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
+      auto func_name = std::string(name_node.value());
+      auto target = Target("ext_dev");
+      auto global_var = GlobalVar(func_name);
+      global_var->checked_type_ = key->source_func->checked_type();
+      ir_module->Add(global_var, key->source_func);
+      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
+      return value;
+    }
+    // Enforce use the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+    auto cfunc = PrimFuncFor(key->source_func, key->target,
+                             [&](std::string name) { return GetUniqueName(name, name_map_); });
+
+    // Skip lowering for device copy node.
+    const Expr body = (key->source_func)->body;
+    if (const CallNode* call_node = body.as<CallNode>()) {
+      if (call_node->attrs.as<DeviceCopyAttrs>()) {
+        value->cached_func = cfunc;
+        return value;
+      }
+    }
+
+    std::cout << "Input Size: " << cfunc->inputs.size() << std::endl;
+    std::cout << "Output Size: " << cfunc->outputs.size() << std::endl;
+    // NOTE: array will copy on write.
+    Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+    for (te::Tensor arg : cfunc->outputs) {
+      all_args.push_back(arg);
+    }
+
+    std::cout << "Allargs Size: " << all_args.size() << std::endl;
+
+    using tvm::transform::PassContext;
+    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+
+    std::unordered_map<te::Tensor, tir::Buffer> binds;
+    auto func_name = cfunc->prim_fn_var->name_hint;
+    cfunc->funcs->Update(tvm::lower(cfunc->schedule, all_args, func_name, binds));

Review comment:
       Same comment as above, 
   > It looks like the compiler has been relying on `if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {` for long enough that the set of passes run in the python impl of tvm.lower (in python/driver/build_module.py) is different than the c++ `tvm::lower`. 
   > 
   > My feeling is that we should use the same flow as before to avoid too much disruption and then refactor after this lands. 

##########
File path: src/runtime/graph_executor/graph_executor.cc
##########
@@ -433,6 +434,7 @@ GraphExecutor::CreateTVMOp(const TVMOpParam& param, const std::vector<DLTensor>&
   ICHECK(pf != nullptr) << "no such function in module: " << param.func_name;
 
   auto fexec = [arg_ptr, pf]() {
+    std::cout << "Number of args: " << static_cast<int>(arg_ptr->arg_values.size());

Review comment:
       ```suggestion
   ```

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    CCacheValue value = LowerInternal(key);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        ICHECK(code_gen.defined()) << "No external codegen is set";
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
+    }
+
+    // No need to cache external functions as we collected them all to create
+    // external runtime modules.
+    for (const auto& it : cached_ext_funcs) {
+      cache_.erase(it);
+    }
+    return ret;
+  }
+
+  void Clear() final { cache_.clear(); }
+
+  // List all items in the cache.
+  Array<ObjectRef> ListItems() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    Array<ObjectRef> items;
+    for (auto& kv : cache_) {
+      items.push_back(kv.first);
+      items.push_back(kv.second);
+    }
+    return items;
+  }
+
+  /*!
+   * \brief Get the cache key of the function that is being lowered currently
+   * \return the cache key
+   */
+  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+  // implement lowered func
+  CCacheValue LowerInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      if (!backend::IsCompileEngineCacheDisabled()) {
+        cache_[key] = value;
+      }
+    }
+    cur_ccache_key_ = key;
+
+    // No need to lower external functions for now. We will invoke the external
+    // codegen tool once and lower all functions together.
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+      auto ir_module = IRModule();
+      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
+      auto func_name = std::string(name_node.value());
+      auto target = Target("ext_dev");
+      auto global_var = GlobalVar(func_name);
+      global_var->checked_type_ = key->source_func->checked_type();
+      ir_module->Add(global_var, key->source_func);
+      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
+      return value;
+    }
+    // Enforce use the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+    auto cfunc = PrimFuncFor(key->source_func, key->target,
+                             [&](std::string name) { return GetUniqueName(name, name_map_); });
+
+    // Skip lowering for device copy node.
+    const Expr body = (key->source_func)->body;
+    if (const CallNode* call_node = body.as<CallNode>()) {
+      if (call_node->attrs.as<DeviceCopyAttrs>()) {
+        value->cached_func = cfunc;
+        return value;
+      }
+    }
+
+    std::cout << "Input Size: " << cfunc->inputs.size() << std::endl;
+    std::cout << "Output Size: " << cfunc->outputs.size() << std::endl;
+    // NOTE: array will copy on write.
+    Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+    for (te::Tensor arg : cfunc->outputs) {
+      all_args.push_back(arg);
+    }
+
+    std::cout << "Allargs Size: " << all_args.size() << std::endl;

Review comment:
       ```suggestion
   ```
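`LowerInternal` in the diff above uses a find-or-insert memoization pattern under a mutex. It creates the cache entry on first use, bumps `use_count` on every request, and returns early when `cached_func` is already defined. The sketch below shows that pattern with hypothetical types. `CacheValue` and `LoweringCache` stand in for `CCacheValueNode` and `TECompilerImpl`, and a string replaces the real lowering. It counts every request, including the first. This matches the later revision's `use_count = 1` initialization. The revision quoted here starts new entries at 0, and it also honors `IsCompileEngineCacheDisabled()`, which the sketch omits.

```cpp
#include <cassert>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for CCacheValueNode: `lowered` plays the role of
// `cached_func.defined()`, `lowered_name` the role of the CachedFunc payload.
struct CacheValue {
  int use_count = 0;
  bool lowered = false;
  std::string lowered_name;
};

class LoweringCache {
 public:
  // Find-or-create the entry for `key`, count the request, and run the
  // (expensive) lowering step only the first time the key is seen.
  std::shared_ptr<CacheValue> Lower(const std::string& key) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::shared_ptr<CacheValue>& slot = cache_[key];  // default-inserts nullptr if absent
    if (!slot) slot = std::make_shared<CacheValue>();
    slot->use_count += 1;
    if (slot->lowered) return slot;         // cache hit: reuse the earlier result
    slot->lowered_name = "lowered_" + key;  // placeholder for the real lowering
    slot->lowered = true;
    ++num_lowerings_;
    return slot;
  }

  // Number of times the lowering step actually ran.
  int num_lowerings() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return num_lowerings_;
  }

 private:
  mutable std::mutex mutex_;
  std::unordered_map<std::string, std::shared_ptr<CacheValue>> cache_;
  int num_lowerings_ = 0;
};
```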

##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -349,65 +380,47 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
     return AddNode(node, GetRef<Expr>(op));
   }
 
-  std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
-    Expr expr = GetRef<Expr>(op);
-    Function func;
-    if (op->op.as<OpNode>()) {
-      LOG(FATAL) << "Operators should be transformed away; try applying"
-                 << "the fuse_ops transformation to the expression.";
-    } else if (op->op.as<GlobalVarNode>()) {
-      LOG(FATAL) << "Not implemented";
-    } else if (op->op.as<FunctionNode>()) {
-      func = GetRef<Function>(op->op.as<FunctionNode>());
-    } else {
-      LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
-    }
-    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
-      LOG(FATAL) << "TVM only support calls to primitive functions "
-                 << "(i.e functions composed of fusable operator invocations)";
-    }
+  std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
+    relay::Call call = GetRef<Call>(call_node);
+    if (auto global_node = call->op.as<GlobalVarNode>()) {
 
-    auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
-    auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
-    Target target;
-    // Handle external function
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      target = Target("ext_dev");
-      CCacheKey key = (*pf0)(func, target);
-      CachedFunc ext_func = (*pf1)(compile_engine_, key);
-      ICHECK(ext_func.defined()) << "External function is not defined.";
-      UpdateConstants(func, &params_);
-      return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
-    }
+      auto prim_fn_name = global_node->name_hint;
 
-    ICHECK_GE(storage_device_map_.count(expr), 0);
-    auto& device_type = storage_device_map_[expr][1];
-    auto call_dev_type = device_type[0]->value;
-    // Normal Relay Function
-    if (targets_.size() == 1) {
-      // homogeneous execution.
-      const auto& it = targets_.begin();
-      target = (*it).second;
-    } else {
-      // heterogeneous execution.
-      std::string call_dev_name;
-      if (call_dev_type == 0) {
-        call_dev_name = "llvm";
+      Target target;
+
+      // // Handle external function
+      // if (func->GetAttr<String>(attr::kCompiler).defined()) {
+      //   UpdateConstants(func, &params_);
+      //   return GraphAddCallNode(call_node, prim_fn_name, prim_fn_name);
+      // }
+
+      ICHECK_GE(storage_device_map_.count(call), 0);
+      auto& device_type = storage_device_map_[call][1];
+      auto call_dev_type = device_type[0]->value;
+      // Normal Relay Function
+      if (targets_.size() == 1) {
+        // homogeneous execution.
+        const auto& it = targets_.begin();
+        target = (*it).second;
       } else {
-        call_dev_name = runtime::DeviceName(call_dev_type);
-      }
-      if (targets_.count(call_dev_type) == 0) {
-        LOG(FATAL) << "No target is provided for device " << call_dev_name;
+        // heterogeneous execution.
+        std::string call_dev_name;
+        if (call_dev_type == 0) {
+          call_dev_name = "llvm";
+        } else {
+          call_dev_name = runtime::DeviceName(call_dev_type);
+        }
+        if (targets_.count(call_dev_type) == 0) {
+          LOG(FATAL) << "No target is provided for device " << call_dev_name;
+        }
+        target = targets_[call_dev_type];
       }
-      target = targets_[call_dev_type];
-    }
-    CCacheKey key = (*pf0)(func, target);
-    CachedFunc lowered_func = (*pf1)(compile_engine_, key);
-    if (!lowered_funcs_.count(target->str())) {
-      lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
+
+      return GraphAddCallNode(call_node, _GetUniqueName(prim_fn_name), prim_fn_name);
+    } else {
+      LOG(FATAL) << "BadCase: " << PrettyPrint(call) << std::endl;

Review comment:
       ```suggestion
         LOG(FATAL) << "Unhandled call type, GraphExec can only lower primitive function calls: " << PrettyPrint(call) << std::endl;
   ```
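The target-selection logic in `VisitExpr_` above works in two cases. With a single target (homogeneous execution), that target is used unconditionally. With several targets (heterogeneous execution), the call's device type is looked up in the target map, and device type 0 is treated as the CPU. The sketch below models this with device types as plain ints and targets as strings. Those types, and the `SelectTarget` name, are assumptions for illustration. The constant 1 follows DLPack's `kDLCPU`. The sketch follows the later `LowerTensorExpr` revision, which remaps device type 0 to `kDLCPU` before the lookup. The version quoted here only renames `call_dev_name` and still looks up `targets_[0]`.

```cpp
#include <cassert>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical: device types as ints; 1 matches DLPack's kDLCPU.
constexpr int kCPUDeviceType = 1;

std::string SelectTarget(const std::map<int, std::string>& targets, int call_dev_type) {
  if (targets.empty()) throw std::runtime_error("no targets provided");
  // Homogeneous execution: exactly one target, use it for every call.
  if (targets.size() == 1) return targets.begin()->second;
  // Heterogeneous execution: device type 0 means "unannotated", treat as CPU.
  if (call_dev_type == 0) call_dev_type = kCPUDeviceType;
  auto it = targets.find(call_dev_type);
  if (it == targets.end()) {
    std::ostringstream msg;
    msg << "No target is specified for device type " << call_dev_type
        << "; available device types:";
    for (const auto& kv : targets) msg << ' ' << kv.first;
    throw std::runtime_error(msg.str());
  }
  return it->second;
}
```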

##########
File path: src/target/llvm/llvm_module.cc
##########
@@ -234,7 +238,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
       }
       funcs.push_back(f);
     }
-    ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params));
+    // ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params));

Review comment:
       This should still be a valid check, no?





[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r664942432



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,760 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+// TODO(@jroesch, @csullivan): declare directly elsewhere
+backend::StaticMemoryPlan GraphPlanMemory(const Function& func);
+
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key, std::function<String(String)> mangle_fn) {
+    return LowerInternal(key, mangle_fn)->cached_func;
+  }
+
+  CachedFunc Lower(const CCacheKey& key, const String mod_name) {
+    auto mangle_fn = [mod_name](String name) {
+      std::cout << "inner mod name" << mod_name << std::endl;
+      return runtime::get_name_mangled(mod_name, name);
+    };
+
+    return Lower(key, mangle_fn);
+  }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    auto mangle_fn = [](String name) { return name; };
+    CCacheValue value = LowerInternal(key, mangle_fn);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
+    }
+
+    // No need to cache external functions as we collected them all to create
+    // external runtime modules.
+    for (const auto& it : cached_ext_funcs) {
+      cache_.erase(it);
+    }
+    return ret;
+  }
+
+  void Clear() final { cache_.clear(); }
+
+  // List all items in the cache.
+  Array<ObjectRef> ListItems() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    Array<ObjectRef> items;
+    for (auto& kv : cache_) {
+      items.push_back(kv.first);
+      items.push_back(kv.second);
+    }
+    return items;
+  }
+
+  /*!
+   * \brief Get the cache key of the function that is being lowered currently
+   * \return the cache key
+   */
+  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+  // implement lowered func
+  CCacheValue LowerInternal(const CCacheKey& key, std::function<String(String)> mangle_fn) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 1;
+      cache_[key] = value;
+    }
+    cur_ccache_key_ = key;
+
+    // No need to lower external functions for now. We will invoke the external
+    // codegen tool once and lower all functions together.
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+      auto ir_module = IRModule();
+      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
+      auto func_name = GetUniqueName(name_node.value(), &name_map_);
+      auto target = Target("ext_dev");
+      auto global_var = GlobalVar(func_name);
+      global_var->checked_type_ = key->source_func->checked_type();
+      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
+      return value;
+    }
+
+    // Enforce use the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+    auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) {
+      auto mangled = mangle_fn(name);
+      std::cout << "Mangled: " << mangled << std::endl;
+      return GetUniqueName(mangled, &name_map_);
+    });
+
+    // Skip lowering for device copy node.
+    const Expr body = (key->source_func)->body;
+    if (const CallNode* call_node = body.as<CallNode>()) {
+      if (call_node->attrs.as<DeviceCopyAttrs>()) {
+        value->cached_func = cfunc;
+        return value;
+      }
+    }
+
+    // NOTE: array will copy on write.
+    Array<te::Tensor> all_args = Array<te::Tensor>(cfunc->inputs);
+    for (te::Tensor arg : cfunc->outputs) {
+      all_args.push_back(arg);
+    }
+
+    std::unordered_map<te::Tensor, tir::Buffer> binds;
+    auto func_name = cfunc->prim_fn_var->name_hint;
+    cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds));
+    value->cached_func = cfunc;
+    return value;
+  }
+
+  // implement lowered shape func
+  CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = shape_func_cache_.find(key);
+    if (it != shape_func_cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      shape_func_cache_[key] = value;
+    }
+    // Enforce use the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+
+    using tvm::transform::PassContext;
+    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
+    auto cached_func = ShapeFuncFor(key->source_func, key->target, [&](std::string name) {
+      return GetUniqueName(name, &name_map_);
+    });
+
+    value->cached_func = cached_func;
+    return value;
+  }
+
+  std::unordered_map<std::string, int> GetOpWeights() {
+    std::unordered_map<std::string, int> weights;
+    for (auto pair : cache_) {
+      auto value = pair.second;
+      auto name = value->cached_func->prim_fn_var->name_hint;
+      weights[name] = value->use_count;
+    }
+    return weights;
+  }
+
+  /*! \brief compiler cache lock*/
+  std::mutex mutex_;
+  /*! \brief internal name map to get an unique name */
+  std::unordered_map<std::string, int> name_map_;
+  /*! \brief internal compiler cache */
+  std::unordered_map<CCacheKey, CCacheValue> cache_;
+  /*! \brief internal compiler cache for shape funcs */
+  std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
+  /*! \brief the cache key of the function that is being lowered currently*/
+  CCacheKey cur_ccache_key_;
+};
+
+TECompiler::TECompiler() {
+  auto object = make_object<TECompilerImpl>();
+  data_ = object;
+}
+
+using AnalysisRemapping = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>;
+
+std::tuple<bool, int, int> IsDeviceCopy(const Function& func) {
+  if (auto call_node = func->body.as<CallNode>()) {
+    if (auto op_node = call_node->op.as<OpNode>()) {
+      if (op_node->name == "device_copy") {
+        auto attrs = call_node->attrs.as<DeviceCopyAttrs>();
+        auto dst = attrs->dst_dev_type;
+        auto src = attrs->src_dev_type;
+        return std::tuple<bool, int, int>(true, src, dst);
+      }
+    }
+  }
+
+  return std::tuple<bool, int, int>(false, -1, -1);
+}
+
+class LowerTensorExpr : public ExprMutator {
+ public:
+  LowerTensorExpr(const IRModule& module, const TargetMap& targets, const DeviceMap& device_ctx_map,
+                  ProcessFn process_fn, AnalysisRemapping* prim_fn_to_call,
+                  const String& module_name, TECompiler compiler)
+      : module_(module),
+        targets_(targets),
+        device_context_map_(device_ctx_map),
+        process_fn(process_fn),
+        prim_fn_to_call(prim_fn_to_call),
+        module_name_(module_name),
+        compiler_(compiler) {}
+
+  Expr VisitExpr_(const CallNode* call) override {
+    Call expr = GetRef<Call>(call);
+    Function func;
+
+    if (call->op.as<FunctionNode>()) {
+      func = GetRef<Function>(call->op.as<FunctionNode>());
+    } else {
+      return ExprMutator::VisitExpr_(call);
+    }
+
+    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
+      // Provide a callback hook which allows one-level up code generators to
+      // act when we process a function.
+      this->process_fn(func);
+      return ExprMutator::VisitExpr_(call);
+    }
+
+    // Process inputs.
+    Array<Expr> args;
+    for (size_t i = 0; i < expr->args.size(); i++) {
+      args.push_back(VisitExpr(expr->args[i]));
+    }
+
+    Target target;
+
+    if (func->GetAttr<String>(attr::kCompiler).defined()) {
+      target = Target("ext_dev");
+      CCacheKey key = CCacheKey(func, target);
+      CachedFunc ext_func = compiler_->Lower(key, module_name_);
+      ICHECK(ext_func.defined()) << "Lowering returned undefined function for "
+                                 << ext_func->prim_fn_var->name_hint;
+
+      Map<GlobalVar, tir::PrimFunc> prim_fns;
+
+      for (auto prim_fn : ext_func->funcs->functions) {
+        CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
+        prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
+      }
+
+      relay::Function func_with_metadata = func;
+      func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", ext_func->prim_fn_var);
+      func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
+      func_with_metadata = WithAttr(func_with_metadata, "target", ext_func->target);
+
+      // Provide a callback hook which allows one-level up code generators to
+      // act when we process a function.
+      this->process_fn(func_with_metadata);
+
+      auto ret_call = Call(ext_func->prim_fn_var, args, {});
+      (*prim_fn_to_call)[func] = ret_call;
+      return std::move(ret_call);
+    }
+
+    ICHECK_GE(device_context_map_.count(expr), 0)
+        << "Could not find an entry in the device context map for " << PrettyPrint(expr)
+        << "The memory planning was either not performed for this precise node, or there is bug "
+           "in the memory planner.";
+
+    auto& device_context = this->device_context_map_[expr];
+    auto call_dev_type = device_context.device_type;
+
+    // Non-External Relay Function
+    if (targets_.size() == 1) {
+      // The homogeneous execution case, we should only have one target
+      // so we just grab it.
+      const auto& it = targets_.begin();
+      target = (*it).second;
+    } else {
+      // The heterogeneous execution case we have multiple targets
+      // in this case.
+      //
+      // We need to identify the target and translate.
+      std::string call_dev_name;
+      if (call_dev_type == 0) {
+        call_dev_name = "llvm";
+        call_dev_type = kDLCPU;
+      } else {
+        call_dev_name = ::tvm::runtime::DeviceName(call_dev_type);
+      }
+
+      if (targets_.count(call_dev_type) == 0) {
+        std::stringstream msg;
+        msg << "No target is specified for provided device name: `" << call_dev_name << "`\n\n";
+        msg << call_dev_name << " mapped to device type (" << call_dev_type
+            << ") which was not found in the target map.\n";
+        msg << "Availible targets: \n";
+        for (auto target : targets_) {
+          msg << "  " << target.first << "-> " << target.second << "\n";
+        }
+        LOG(FATAL) << msg.str();
+      }
+
+      target = targets_[call_dev_type];
+    }
+
+    CCacheKey key = CCacheKey(func, target);
+    CachedFunc lowered_func = compiler_->Lower(key, module_name_);
+
+    Map<GlobalVar, tir::PrimFunc> prim_fns;
+
+    for (auto prim_fn : lowered_func->funcs->functions) {
+      CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
+      prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
+    }
+
+    // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
+    relay::Function func_with_metadata = func;
+    func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", lowered_func->prim_fn_var);
+    func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
+    func_with_metadata = WithAttr(func_with_metadata, "target", lowered_func->target);
+
+    // Provide a callback hook which allows one-level up code generators to
+    // act when we process a function.
+    this->process_fn(func_with_metadata);
+
+    auto tir_call_attrs = make_object<TIRCallAttrs>();
+    if (func->HasNonzeroAttr(attr::kReshapeOnly)) {
+      tir_call_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
+    }
+
+    auto device_copy = IsDeviceCopy(func);
+    if (std::get<0>(device_copy)) {
+      std::cout << "DeviceCopy" << std::endl;

Review comment:
       Remove this
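In `LowerExternalFunctions` in the diff above, `cached_symbol` is only written inside the branch guarded by `cached_symbol.count(sn)`. That branch can never be taken, because nothing is inserted on the other path. The `else` branch then compares the symbol name against the codegen name rather than against previously seen symbols. The check presumably intends to reject a global symbol that appears twice. Here is a sketch of that reading, using hypothetical names (`ExternalFunc`, `CheckUniqueSymbols`) and plain strings in place of Relay functions and attributes:

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical: an external function reduced to its global symbol and the
// name of the codegen ("relay.ext.<name>") that will compile it.
struct ExternalFunc {
  std::string symbol;
  std::string codegen;
};

// Record each symbol the first time it is seen, and reject any repeat.
std::unordered_map<std::string, std::string> CheckUniqueSymbols(
    const std::vector<ExternalFunc>& funcs) {
  std::unordered_map<std::string, std::string> symbol_to_codegen;
  for (const auto& f : funcs) {
    auto inserted = symbol_to_codegen.emplace(f.symbol, f.codegen);
    if (!inserted.second) {  // symbol already present: duplicate
      throw std::runtime_error("Found duplicated symbol: " + f.symbol + " for: " + f.codegen +
                               " (first claimed by " + inserted.first->second + ")");
    }
  }
  return symbol_to_codegen;
}
```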






[GitHub] [tvm] jroesch commented on pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#issuecomment-805594042


   Need to port the fix from #7703, but otherwise ready for review.




[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r664941440



##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -176,112 +179,88 @@ class GraphOpNode : public GraphNode {
   const std::string op_type_name_{"tvm_op"};
 };
 
-/*! \brief Code generator for graph executor */
+/*! \brief Code generator for the graph executor, produces a module containing the graph JSON,
+ * module, and parameters.
+ */
 class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
  public:
-  GraphExecutorCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
-    compile_engine_ = CompileEngine::Global();
+  GraphExecutorCodegen(runtime::Module* mod, const TargetMap& targets) : mod_(mod) {
     targets_ = targets;
   }
 
-  /*!
-   * \brief Update the "main" control function's metadata
-   *
-   * \param func The main function that contains calls to relay primitive functions
-   */
-  void UpdateMainWorkspaceSize(const Function& func) {
-    // This is a Map<device,Map<storage_id, size>>
-    std::unordered_map<int, std::unordered_map<int, int>> sid_workspace;
-    // This is a Map<device, size_of_inputs_and_outputs>
-    std::unordered_map<int, int> device_io;
-    // This is a Map<device, size_of_constants>
-    std::unordered_map<int, int> device_consts;
-
-    // Initialize the maps to zero
-    for (const auto& kv : storage_device_map_) {
-      auto sids = kv.second[0];
-      auto devices = kv.second[1];
-      CHECK_EQ(sids.size(), devices.size());
-      for (uint32_t i = 0; i < sids.size(); i++) {
-        sid_workspace[devices[i]][sids[i]] = 0;
-        device_io[devices[i]] = 0;
-        device_consts[devices[i]] = 0;
-      }
-    }
-
-    // Collect sizes of tensors
-    for (const auto& kv : storage_device_map_) {
-      auto size_bytes = CalculateRelayExprSizeBytes(kv.first->checked_type());
-      auto sids = kv.second[0];
-      auto devices = kv.second[1];
-      if (kv.first->IsInstance<ConstantNode>()) {
-        for (const auto& dev : devices) {
-          device_consts[dev] += size_bytes;
-        }
-        continue;
-      } else if (kv.first->IsInstance<VarNode>() || kv.first == func->body) {
-        for (const auto& dev : devices) {
-          device_io[dev] += size_bytes;
-        }
-        continue;
-      }
-      for (uint32_t i = 0; i < sids.size(); i++) {
-        // Here we record the largest size of the tensor
-        // that share the same storage id, because storage_id will
-        // be shared between multiple tensors that are not live simultaneously.
-        if (size_bytes > sid_workspace[devices[i]][sids[i]]) {
-          sid_workspace[devices[i]][sids[i]] = size_bytes;
-        }
-      }
-    }
-
-    // This is a Map<device, workspace_size>
-    std::unordered_map<int, int> device_workspace;
-    // Once we know the sizes of sids, we need to accumulate per device
-    for (const auto& dev_sid_size : sid_workspace) {
-      auto dev = dev_sid_size.first;
-      device_workspace[dev] = 0;
-      for (const auto& sid_size : dev_sid_size.second) {
-        device_workspace[dev] += sid_size.second;
-      }
-    }
-
-    // Populate FunctionInfo
-    auto fi_node = make_object<FunctionInfoNode>();
-    // Initialize all target workspaces to zero
-    for (const auto& kv : targets_) {
-      auto tgt = kv.second;
-      fi_node->workspace_sizes.Set(tgt, 0);
-    }
-    for (const auto& dev_and_size : device_workspace) {
-      auto tgt = GetTargetFromInteger(dev_and_size.first);
-      fi_node->workspace_sizes.Set(tgt, dev_and_size.second);
-      fi_node->relay_primfuncs.Set(tgt, func);
-    }
-    for (const auto& dev_and_size : device_io) {
-      auto tgt = GetTargetFromInteger(dev_and_size.first);
-      fi_node->io_sizes.Set(tgt, dev_and_size.second);
-    }
-    for (const auto& dev_and_size : device_consts) {
-      auto tgt = GetTargetFromInteger(dev_and_size.first);
-      fi_node->constant_sizes.Set(tgt, dev_and_size.second);
-    }
-
-    function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node));
+  StorageInfo GetStorageInfo(const Expr& e) {
+    size_t count = memory_plan_->expr_to_storage_info.count(e);
+    ICHECK_GT(count, 0) << "Expr is not existing in storage plan";
+    auto storage_info = memory_plan_->expr_to_storage_info[e];
+    return storage_info;
   }
 
   LoweredOutput Codegen(relay::Function func, String mod_name) {
-    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
-    storage_device_map_ = (*pf)(func);
     mod_name_ = mod_name;
-    UpdateMainWorkspaceSize(func);
+
+    std::cout << "MODULE_NAME: " << mod_name_ << std::endl;

Review comment:
       Remove
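The removed `UpdateMainWorkspaceSize` computed each device's workspace as the sum, over storage ids, of the largest tensor assigned to that id. Tensors that share a storage id are never live at the same time, so only the maximum matters. Constants and I/O tensors were tallied separately. The following is a simplified standalone sketch of that accounting. It assumes a flattened `PlannedTensor` record (hypothetical, not a TVM type) in place of `storage_device_map_`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Hypothetical classification of a planned expression (constant, graph input/output, or temporary).
enum class ExprKind { kConstant, kIO, kIntermediate };

// Hypothetical flattened storage-plan entry: size plus parallel (storage id, device) lists.
struct PlannedTensor {
  ExprKind kind;
  int64_t size_bytes;
  std::vector<int> storage_ids;
  std::vector<int> devices;  // devices[i] is the device holding storage_ids[i]
};

struct DeviceSizes {
  std::unordered_map<int, int64_t> workspace;
  std::unordered_map<int, int64_t> io;
  std::unordered_map<int, int64_t> constants;
};

DeviceSizes ComputeDeviceSizes(const std::vector<PlannedTensor>& plan) {
  DeviceSizes out;
  // device -> (storage id -> largest tensor size seen for that id)
  std::unordered_map<int, std::unordered_map<int, int64_t>> sid_workspace;
  for (const auto& t : plan) {
    assert(t.storage_ids.size() == t.devices.size());
    if (t.kind == ExprKind::kConstant) {
      for (int dev : t.devices) out.constants[dev] += t.size_bytes;
      continue;
    }
    if (t.kind == ExprKind::kIO) {
      for (int dev : t.devices) out.io[dev] += t.size_bytes;
      continue;
    }
    for (size_t i = 0; i < t.storage_ids.size(); ++i) {
      int64_t& slot = sid_workspace[t.devices[i]][t.storage_ids[i]];
      slot = std::max(slot, t.size_bytes);  // keep the largest tensor sharing this id
    }
  }
  // Per-device workspace is the sum of the per-storage-id maxima.
  for (const auto& dev_sids : sid_workspace) {
    int64_t total = 0;
    for (const auto& sid_size : dev_sids.second) total += sid_size.second;
    out.workspace[dev_sids.first] = total;
  }
  return out;
}
```

For example, two temporaries of 16 and 32 bytes that reuse storage id 1 contribute 32 bytes, not 48.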







[GitHub] [tvm] jroesch commented on pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#issuecomment-805594912


   Modulo some leftover polish work and documentation, I think this is ready for review @icemelon9 @comaniac @csullivan @tkonolige @rkimball @junrushao1994 @areusch @mehrdadh 





[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r664943391



##########
File path: tests/python/unittest/test_micro_model_library_format.py
##########
@@ -267,7 +267,7 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[
     "target",
     [

Review comment:
       Restore







[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r664943131



##########
File path: src/target/llvm/llvm_module.cc
##########
@@ -234,7 +238,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
       }
       funcs.push_back(f);
     }
-    ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params));
+    // ICHECK(funcs.size() > 0 || (could_have_linked_params && found_linked_params));

Review comment:
       Add follow-up comment here.







[GitHub] [tvm] areusch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
areusch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r604192779



##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -181,23 +186,56 @@ class GraphOpNode : public GraphNode {
   const std::string op_type_name_{"tvm_op"};
 };
 
-/*! \brief Code generator for graph runtime */
+/*! \brief Code generator for the graph runtime, produces a module containing the graph JSON,
+ * module, and parameters.
+ */
 class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
  public:
   GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
-    compile_engine_ = CompileEngine::Global();
     targets_ = targets;
   }
 
   LoweredOutput Codegen(relay::Function func) {
-    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
-    storage_device_map_ = (*pf)(func);
+    // Jared: why do we do this? just call C++ API.

Review comment:
      I don't think it makes sense to allow people to overload GraphPlanMemory with its current implementation. If GraphPlanMemory were to support memory pools and tensor pinning, then I think there would be cases where people want to implement custom algorithms to place tensors in different memory pools. 
   
   I also think the reason this "seems bad" is that there is no good interface/data structure defined here. There's no reason we can't define one--in fact, I think we are hurting compiler readability by not doing so. Here we are keeping the interface to a sort of Pythonic "informal set of lists" in place of a proper data structure. Here and downstream of here, there are many instances where both TVM code and people trying to integrate with TVM code would benefit from effectively exporting StorageToken as a user-facing data structure and clearly defining its meaning.
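The following illustrates the "informal set of lists" versus a defined data structure. Today the planner returns, per expression, an array whose index 0 holds storage ids and index 1 holds device types, and every consumer must know that convention. The sketch below is a minimal, hypothetical typed alternative (`StorageInfoSketch` and `FromInformal` are illustrative names, not TVM APIs). It names the fields and checks their shared-length invariant once.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Untyped form: outer[0] = storage ids, outer[1] = device types (positional convention only).
using InformalStorageInfo = std::vector<std::vector<int64_t>>;

// Hypothetical typed form: named fields, invariant checked at construction.
struct StorageInfoSketch {
  std::vector<int64_t> storage_ids;
  std::vector<int> device_types;  // device_types[i] is the device of storage_ids[i]

  StorageInfoSketch(std::vector<int64_t> sids, std::vector<int> devs)
      : storage_ids(std::move(sids)), device_types(std::move(devs)) {
    if (storage_ids.size() != device_types.size()) {
      throw std::invalid_argument("storage_ids and device_types must have equal length");
    }
  }
};

// The positional convention is decoded in exactly one place.
StorageInfoSketch FromInformal(const InformalStorageInfo& info) {
  if (info.size() != 2) {
    throw std::invalid_argument("expected exactly [storage_ids, device_types]");
  }
  std::vector<int> devs;
  devs.reserve(info[1].size());
  for (int64_t d : info[1]) devs.push_back(static_cast<int>(d));
  return StorageInfoSketch(info[0], std::move(devs));
}
```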







[GitHub] [tvm] jroesch commented on pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#issuecomment-805600377


   cc @MatthewARM 





[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r603734939



##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -181,23 +186,56 @@ class GraphOpNode : public GraphNode {
   const std::string op_type_name_{"tvm_op"};
 };
 
-/*! \brief Code generator for graph runtime */
+/*! \brief Code generator for the graph runtime, produces a module containing the graph JSON,
+ * module, and parameters.
+ */
 class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
  public:
   GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
-    compile_engine_ = CompileEngine::Global();
     targets_ = targets;
   }
 
   LoweredOutput Codegen(relay::Function func) {
-    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
-    storage_device_map_ = (*pf)(func);
+    // Jared: why do we do this? just call C++ API.

Review comment:
      I think we should ask: what do we allow customization of, and where? Allowing people to wholesale replace the lowering, memory planning, etc. in an ad-hoc way "seems bad", and that is effectively all these hooks (which I'm removing) do right now. They let you arbitrarily overload parts of the compiler with custom code, which is non-compositional and hard to reason about, since in certain cases (e.g. whether Python is loaded or not) the behavior of the compiler is completely different. 
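The following illustrates the compositionality concern with a toy example. A pass that looks up a global registry by name changes behavior depending on what happened to register itself, whereas an explicitly passed function keeps the dependency visible at the call site. The registry, the name `backend.PlanMemory`, and both plan functions are hypothetical stand-ins, not TVM's packed-function registry.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>

// Toy global registry (hypothetical stand-in for a name -> function registry).
std::map<std::string, std::function<int(int)>>& Registry() {
  static std::map<std::string, std::function<int(int)>> registry;
  return registry;
}

// Default "planner": one buffer per tensor.
int DefaultPlan(int num_tensors) { return num_tensors; }

// Hook style: silently uses whatever is registered under the name, if anything.
int PlanViaHook(int num_tensors) {
  auto it = Registry().find("backend.PlanMemory");
  return it != Registry().end() ? it->second(num_tensors) : DefaultPlan(num_tensors);
}

// Explicit style: the planner is a parameter, so the behavior is fixed by the caller.
int PlanExplicit(int num_tensors, const std::function<int(int)>& planner = DefaultPlan) {
  return planner(num_tensors);
}
```

The same `PlanViaHook(4)` call returns 4 or 2 depending on whether some other code registered an override first. `PlanExplicit(4)` keeps returning 4 unless its caller asks otherwise.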







[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r664941736



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,760 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+// TODO(@jroesch, @csullivan): declare directly elsewhere
+backend::StaticMemoryPlan GraphPlanMemory(const Function& func);
+
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key, std::function<String(String)> mangle_fn) {
+    return LowerInternal(key, mangle_fn)->cached_func;
+  }
+
+  CachedFunc Lower(const CCacheKey& key, const String mod_name) {
+    auto mangle_fn = [mod_name](String name) {
+      std::cout << "inner mod name" << mod_name << std::endl;

Review comment:
       Remove
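The `Lower(key, mod_name)` overload builds a `mangle_fn` closure that captures `mod_name` by value and passes it down to lowering. The following is a minimal sketch of such a closure without the debug print. The `<mod_name>_<name>` format is a hypothetical scheme for illustration, not necessarily the one TVM uses.

```cpp
#include <cassert>
#include <functional>
#include <string>

// Build a name mangler bound to one module name (hypothetical "<mod_name>_<name>" format).
std::function<std::string(const std::string&)> MakeMangleFn(const std::string& mod_name) {
  // Capture by value so the closure stays valid after the caller's string is gone,
  // as the [mod_name] capture in the diff does. No logging inside the closure.
  return [mod_name](const std::string& name) { return mod_name + "_" + name; };
}
```

Binding the module name into the closure lets two modules that both contain a `fused_add` lower to distinct symbols.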







[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r608923219



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,372 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    CCacheValue value = LowerInternal(key);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        ICHECK(code_gen.defined()) << "No external codegen is set";
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
+    }
+
+    // No need to cache external functions as we collected them all to create
+    // external runtime modules.
+    for (const auto& it : cached_ext_funcs) {
+      cache_.erase(it);
+    }
+    return ret;
+  }
+
+  void Clear() final { cache_.clear(); }
+
+  // List all items in the cache.
+  Array<ObjectRef> ListItems() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    Array<ObjectRef> items;
+    for (auto& kv : cache_) {
+      items.push_back(kv.first);
+      items.push_back(kv.second);
+    }
+    return items;
+  }
+
+  /*!
+   * \brief Get the cache key of the function that is being lowered currently
+   * \return the cache key
+   */
+  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+  // implement lowered func
+  CCacheValue LowerInternal(const CCacheKey& key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 0;
+      if (!backend::IsCompileEngineCacheDisabled()) {
+        cache_[key] = value;
+      }
+    }
+    cur_ccache_key_ = key;
+
+    // No need to lower external functions for now. We will invoke the external
+    // codegen tool once and lower all functions together.
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+      auto ir_module = IRModule();
+      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";

Review comment:
       There is no name in this case?
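`LowerInternal` in the hunk above follows a memoize-under-lock pattern. It takes the mutex and looks up the key. On a hit it bumps `use_count` and returns early if the function is already lowered. On a miss it inserts a fresh value with `use_count = 0`, then lowers. The following is a simplified standalone sketch of that pattern. It assumes string keys and a caller-supplied lowering function in place of `CCacheKey`/`CCacheValue`, and it omits the `IsCompileEngineCacheDisabled` bypass.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for CCacheValueNode.
struct CacheValue {
  int use_count = 0;
  std::string lowered;  // empty until lowering has run
};

class LoweringCache {
 public:
  explicit LoweringCache(std::function<std::string(const std::string&)> lower)
      : lower_(std::move(lower)) {}

  // Lowers each key at most once; repeated lookups only bump use_count.
  std::shared_ptr<CacheValue> Get(const std::string& key) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::shared_ptr<CacheValue> value;
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      value = it->second;
      value->use_count += 1;
      if (!value->lowered.empty()) return value;
    } else {
      value = std::make_shared<CacheValue>();  // use_count starts at 0
      cache_[key] = value;
    }
    value->lowered = lower_(key);
    return value;
  }

 private:
  std::function<std::string(const std::string&)> lower_;
  std::unordered_map<std::string, std::shared_ptr<CacheValue>> cache_;
  std::mutex mutex_;
};
```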







[GitHub] [tvm] areusch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
areusch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r603671640



##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -181,23 +186,56 @@ class GraphOpNode : public GraphNode {
   const std::string op_type_name_{"tvm_op"};
 };
 
-/*! \brief Code generator for graph runtime */
+/*! \brief Code generator for the graph runtime, produces a module containing the graph JSON,
+ * module, and parameters.
+ */
 class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
  public:
   GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
-    compile_engine_ = CompileEngine::Global();
     targets_ = targets;
   }
 
   LoweredOutput Codegen(relay::Function func) {
-    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
-    storage_device_map_ = (*pf)(func);
+    // Jared: why do we do this? just call C++ API.

Review comment:
      At some point we should support overriding this. It seems like you'd need to build a C++ impl to do that right now, so I don't know that we need this hook now. But should we plan to put some customization back in the future?
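One way to put customization back without a global hook is an explicit interface that the codegen receives, with a default when none is given. The following is a minimal sketch of that idea. `MemoryPlannerSketch`, `NaivePlanner`, `PingPongPlanner`, and `CodegenSketch` are all hypothetical names, not TVM classes.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical customization point: maps each tensor (by index) to a storage id.
class MemoryPlannerSketch {
 public:
  virtual ~MemoryPlannerSketch() = default;
  virtual std::vector<int> Plan(const std::vector<int>& sizes) const = 0;
};

// Default: every tensor gets its own storage id.
class NaivePlanner : public MemoryPlannerSketch {
 public:
  std::vector<int> Plan(const std::vector<int>& sizes) const override {
    std::vector<int> ids(sizes.size());
    for (size_t i = 0; i < ids.size(); ++i) ids[i] = static_cast<int>(i);
    return ids;
  }
};

// Alternates two buffers along a chain of ops; assumes equal-sized tensors.
class PingPongPlanner : public MemoryPlannerSketch {
 public:
  std::vector<int> Plan(const std::vector<int>& sizes) const override {
    std::vector<int> ids(sizes.size());
    for (size_t i = 0; i < ids.size(); ++i) ids[i] = static_cast<int>(i % 2);
    return ids;
  }
};

// The planner is an explicit constructor argument rather than a registry lookup.
class CodegenSketch {
 public:
  explicit CodegenSketch(
      std::unique_ptr<MemoryPlannerSketch> planner = std::make_unique<NaivePlanner>())
      : planner_(std::move(planner)) {}

  int NumBuffers(const std::vector<int>& sizes) const {
    int max_id = -1;
    for (int id : planner_->Plan(sizes)) max_id = std::max(max_id, id);
    return max_id + 1;
  }

 private:
  std::unique_ptr<MemoryPlannerSketch> planner_;
};
```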

##########
File path: src/relay/backend/graph_plan_memory.cc
##########
@@ -252,9 +265,22 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
   }
   // The call map
   void VisitExpr_(const CallNode* op) final {
+    // temporary hack for change of style.

Review comment:
       same

##########
File path: src/driver/driver_api.cc
##########
@@ -244,14 +244,17 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
   }
 
   if (target->kind->device_type == kDLCPU && target_host == target) {
-    ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
-                                       << "and host_target are both llvm target."
-                                       << "\n";
+    // We need to relax this check for just TIR functions.

Review comment:
       +1

##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
+  auto n = make_object<LoweredOutputNode>();
+  n->outputs = std::move(outputs);
+  n->implementation = std::move(impl);
+  data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+  auto n = make_object<CCacheKeyNode>();
+  n->source_func = std::move(source_func);
+  n->target = std::move(target);
+  data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array<te::Tensor> inputs,
+                       tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+                       tvm::Array<Integer> shape_func_param_states, IRModule funcs) {
+  auto n = make_object<CachedFuncNode>();
+  n->target = target;
+  n->prim_fn_var = prim_fn_var;
+  n->inputs = inputs;
+  n->outputs = outputs;
+  n->schedule = schedule;
+  n->shape_func_param_states = shape_func_param_states;
+  n->funcs = funcs;
+  data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+  // for now, we always use int32 shape when possible
+  // even if the result of shape inference becomes int64.
+  Array<IndexExpr> res;
+  for (IndexExpr val : shape) {
+    const int64_t* pval = tir::as_const_int(val);
+    if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+      ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+      ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+      res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+      res.push_back(val);
+#endif  // TVM_INDEX_DEFAULT_I64
+    } else if (val->IsInstance<tir::AnyNode>()) {
+      res.push_back(val.as<tir::AnyNode>()->ToVar());
+    } else {
+      res.push_back(val);
+    }
+  }
+  return res;
+}
+
+// The getter to get schedule from compile engine.

Review comment:
      This isn't really a Getter, is it? What does it translate from and to?
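`GetShape` in the hunk above narrows constant dimensions to int32 unless `TVM_INDEX_DEFAULT_I64` is defined, checking they fit, and passes non-constant dimensions through (turning `Any` into a variable). The following is a simplified standalone sketch of that narrowing. It assumes `std::optional<int64_t>` as a stand-in for an `IndexExpr` that is either a known constant or dynamic, and it throws where the original uses `ICHECK`.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <stdexcept>
#include <vector>

using Dim = std::optional<int64_t>;    // known constant, or std::nullopt for a dynamic dim
using Dim32 = std::optional<int32_t>;  // narrowed result

std::vector<Dim32> NarrowShape(const std::vector<Dim>& shape) {
  std::vector<Dim32> res;
  res.reserve(shape.size());
  for (const Dim& d : shape) {
    if (d.has_value()) {
      // Range check mirrors the ICHECK_LE/ICHECK_GE pair in GetShape.
      if (*d > std::numeric_limits<int32_t>::max() || *d < std::numeric_limits<int32_t>::min()) {
        throw std::out_of_range("shape dimension does not fit in int32");
      }
      res.push_back(static_cast<int32_t>(*d));
    } else {
      res.push_back(std::nullopt);  // dynamic dimensions pass through unchanged
    }
  }
  return res;
}
```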

##########
File path: tests/python/relay/test_backend_graph_runtime.py
##########
@@ -142,7 +143,7 @@ def test_plan_memory():
     # Current rule requires vars have unique storage id
     # because we don't do inplace, we will need another
     # two alternating temporary space.
-    assert len(storage_ids) == 4
+    assert len(storage_ids) == 4, f"found storaged_ids: {storage_ids}"

Review comment:
       nit: storage_ids

##########
File path: tests/python/relay/test_backend_graph_runtime.py
##########
@@ -126,6 +126,7 @@ def test_plan_memory():
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
     mod = relay.transform.InferType()(mod)
+    print(mod)

Review comment:
       remove?

##########
File path: tests/python/relay/test_backend_graph_runtime.py
##########
@@ -231,10 +232,12 @@ def test_graph_executor_nested_tuples():
 
 
 if __name__ == "__main__":
-    test_plan_memory()
-    test_with_params()
-    test_add_op_scalar()
     test_add_op_tensor()
-    test_add_op_broadcast()
-    test_gru_like()
-    test_compile_nested_tuples()
+    # test_plan_memory()
+    # test_with_params()
+    # test_add_op_scalar()
+    # test_add_op_tensor()
+    # test_add_op_broadcast()
+    # test_gru_like()
+    # test_compile_nested_tuples()
+    # test_add_op_tensor()

Review comment:
       sys.exit(pytest.main([__file__] + sys.argv[1:]))

##########
File path: src/relay/backend/graph_plan_memory.cc
##########
@@ -166,10 +167,22 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor {
   }
 
   void VisitExpr_(const CallNode* op) final {
+    // temporary hack for change of style.

Review comment:
       can you state which arg is being dropped here?







[GitHub] [tvm] icemelon9 commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
icemelon9 commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r609055578



##########
File path: src/relay/backend/te_compiler.h
##########
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/backend/tir_compiler.h
+ *  * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns.
+ *
+ *
+ * This represents the new design of the Relay compilation flow and will replace the interface
+ * contained in compile_engine.h as we migrate towards a standard pass based lowering of
+ * Relay functions.
+ *
+ * This files provides an internal API which lowers Relay programs to components which
+ * can be combined with TVM produced kernels to compile an entire program.
+ *
+ * The result of lowering contains a combination of `runtime::Module`s produced by external
+ * compilers and a set of lowered PrimFns which can be code generated for targets.
+ */
+#ifndef TVM_RELAY_BACKEND_TE_COMPILER_H_
+#define TVM_RELAY_BACKEND_TE_COMPILER_H_
+
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op_strategy.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/module.h>
+#include <tvm/topi/elemwise.h>
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "../transforms/infer_layout_utils.h"
+#include "../transforms/pass_utils.h"
+#include "./te_compiler_cache.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity sake
+// we should a version of context which works in Map
+using TargetsMap = std::unordered_map<int, Target>;

Review comment:
       ```suggestion
   using TargetMap = std::unordered_map<int, Target>;
   ```
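The renamed `TargetMap` is keyed by an integer device type, and the TODO notes it should eventually become a `tvm::Map`. The following is a minimal sketch of a lookup over such a map that fails loudly on a missing device type, in the spirit of the `GetTargetFromInteger` calls in the removed code. It assumes target strings in place of `Target` objects and illustrative device-type integers; `GetTargetForDevice` is a hypothetical helper.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Simplified stand-in: device type -> target string.
using TargetMap = std::unordered_map<int, std::string>;

// Single find() instead of count() followed by operator[], and no silent insertion.
const std::string& GetTargetForDevice(const TargetMap& targets, int device_type) {
  auto it = targets.find(device_type);
  if (it == targets.end()) {
    throw std::out_of_range("no target registered for device type " +
                            std::to_string(device_type));
  }
  return it->second;
}
```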

##########
File path: src/relay/backend/te_compiler.h
##########
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/backend/te_compiler.h
+ * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns.
+ *
+ *
+ * This represents the new design of the Relay compilation flow and will replace the interface
+ * contained in compile_engine.h as we migrate towards a standard pass-based lowering of
+ * Relay functions.
+ *
+ * This file provides an internal API which lowers Relay programs to components which
+ * can be combined with TVM produced kernels to compile an entire program.
+ *
+ * The result of lowering contains a combination of `runtime::Module`s produced by external
+ * compilers and a set of lowered PrimFns which can be code generated for targets.
+ */
+#ifndef TVM_RELAY_BACKEND_TE_COMPILER_H_
+#define TVM_RELAY_BACKEND_TE_COMPILER_H_
+
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op_strategy.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/module.h>
+#include <tvm/topi/elemwise.h>
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "../transforms/infer_layout_utils.h"
+#include "../transforms/pass_utils.h"
+#include "./te_compiler_cache.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity's sake;
+// we should add a version of context which works in Map
+using TargetsMap = std::unordered_map<int, Target>;
+using DeviceContextMap =

Review comment:
       ```suggestion
   using DeviceMap =
   ```
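For context on the `DeviceMap` alias: the map is keyed on `Expr` with `runtime::ObjectPtrHash` and `runtime::ObjectPtrEqual`, so lookups go by object identity rather than structural equality. The standalone sketch below (not TVM code) mimics that with `std::shared_ptr`; `Node`, `PtrHash`, `PtrEqual`, and `Device` are hypothetical stand-ins for `Expr`, the TVM hash/equal functors, and `tvm::Device`.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a Relay expression node (not a TVM type).
struct Node {
  std::string op;
};
using NodeRef = std::shared_ptr<const Node>;

// Hash and compare by the address of the pointed-to object, i.e. the identity
// semantics that ObjectPtrHash / ObjectPtrEqual give an Expr key.
struct PtrHash {
  std::size_t operator()(const NodeRef& n) const { return std::hash<const Node*>()(n.get()); }
};
struct PtrEqual {
  bool operator()(const NodeRef& a, const NodeRef& b) const { return a.get() == b.get(); }
};

// Hypothetical device descriptor standing in for tvm::Device.
struct Device {
  int device_type;
  int device_id;
};

// Analogue of the suggested DeviceMap alias.
using DeviceMap = std::unordered_map<NodeRef, Device, PtrHash, PtrEqual>;
```

Two structurally identical nodes get separate entries, while any copy of the same reference finds the same entry, which is presumably what a per-expression device assignment wants: two identical calls may still be placed on different devices.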

##########
File path: src/relay/backend/te_compiler.h
##########
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/backend/te_compiler.h
+ * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns.
+ *
+ *
+ * This represents the new design of the Relay compilation flow and will replace the interface
+ * contained in compile_engine.h as we migrate towards a standard pass-based lowering of
+ * Relay functions.
+ *
+ * This file provides an internal API which lowers Relay programs to components which
+ * can be combined with TVM produced kernels to compile an entire program.
+ *
+ * The result of lowering contains a combination of `runtime::Module`s produced by external
+ * compilers and a set of lowered PrimFns which can be code generated for targets.
+ */
+#ifndef TVM_RELAY_BACKEND_TE_COMPILER_H_
+#define TVM_RELAY_BACKEND_TE_COMPILER_H_
+
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op_strategy.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/module.h>
+#include <tvm/topi/elemwise.h>
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "../transforms/infer_layout_utils.h"
+#include "../transforms/pass_utils.h"
+#include "./te_compiler_cache.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity's sake;
+// we should add a version of context which works in Map
+using TargetsMap = std::unordered_map<int, Target>;
+using DeviceContextMap =
+    std::unordered_map<Expr, tvm::Device, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+using ProcessFn = std::function<void(Function)>;
+
+/*!
+ * \brief A compiler which lowers primitive Relay functions to tensor expressions
+ * and schedules them into TIR functions.
+ */
+class TECompilerNode : public Object {
+ public:
+  /*! \brief destructor */
+  virtual ~TECompilerNode() {}
+  /*!
+   * \brief Get lowered result.
+   * \param key The key to the cached function.
+   * \return The result.
+   */
+  virtual CachedFunc Lower(const CCacheKey& key) = 0;
+
+  virtual Map<String, IRModule> GetLoweredFunctions() = 0;
+  /*!
+   * \brief Just in time compile to get a PackedFunc.
+   * \param key The key to the cached function.
+   * \return The result.
+   */
+  virtual PackedFunc JIT(const CCacheKey& key) = 0;
+  /*!
+   * \brief Lower the shape function.
+   * \param key The key to the cached function.
+   * \return The result.
+   */
+  virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0;
+  /*!
+   * \brief Lower the external function using external codegen tools.
+   * \return The runtime modules for each needed external codegen tool.
+   */
+  virtual tvm::Array<tvm::runtime::Module> LowerExternalFunctions() = 0;
+
+  /*! \brief clear the cache. */
+  virtual void Clear() = 0;
+
+  // VisitAttrs
+  void VisitAttrs(AttrVisitor*) {}
+
+  static constexpr const char* _type_key = "relay.TECompiler";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TECompilerNode, Object);
+};
+
+
+/*! \brief Reference-counted handle to a TECompilerNode. */
+class TECompiler : public ObjectRef {
+ public:
+  TECompiler();
+  explicit TECompiler(ObjectPtr<Object> n) : ObjectRef(n) {}
+  TECompilerNode* operator->() { return static_cast<TECompilerNode*>(get_mutable()); }
+  using ContainerType = TECompilerNode;
+  /*! \brief The global TE compiler. */
+  TVM_DLL static TECompiler& Global();
+};
+
+struct LoweredModule {
+  IRModule main_module;
+  Map<String, IRModule> per_target_module;
+  Array<tvm::runtime::Module> external_mods;
+};
+
+LoweredModule LowerTE(const IRModule& module, TargetsMap targets,

Review comment:
       doc for this func
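A possible starting point for that doc, sketched under assumptions: per the file comment, lowering yields PrimFns grouped per target plus `runtime::Module`s produced by external codegens, and `LowerExternalFunctions` looks each external codegen up under `"relay.ext." + <compiler>`. The toy model below mirrors only that split. `PrimitiveFunc`, `LoweredModuleSketch`, `CodegenRegistry`, and `LowerTESketch` are hypothetical names, plain strings stand in for IRModules and runtime modules, and the `main_module` field is omitted.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins; none of these names are TVM APIs.
struct PrimitiveFunc {
  std::string name;
  std::string target;    // e.g. "llvm" or "cuda"
  std::string compiler;  // non-empty plays the role of attr::kCompiler
};

struct LoweredModuleSketch {
  std::map<std::string, std::vector<std::string>> per_target_module;  // target -> lowered fns
  std::vector<std::string> external_mods;                             // external "runtime modules"
};

// External codegens keyed as "relay.ext.<compiler>".
using CodegenRegistry = std::map<std::string, std::function<std::string(const PrimitiveFunc&)>>;

LoweredModuleSketch LowerTESketch(const std::vector<PrimitiveFunc>& funcs,
                                  const CodegenRegistry& registry) {
  LoweredModuleSketch result;
  for (const auto& f : funcs) {
    if (!f.compiler.empty()) {
      // Offloaded function: hand it to the matching external codegen.
      auto it = registry.find("relay.ext." + f.compiler);
      if (it == registry.end()) {
        throw std::runtime_error("Failed to find the codegen tool for relay.ext." + f.compiler);
      }
      result.external_mods.push_back(it->second(f));
    } else {
      // TVM-lowered function: group it with the other functions for its target.
      result.per_target_module[f.target].push_back(f.name);
    }
  }
  return result;
}
```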

##########
File path: src/relay/backend/te_compiler_cache.h
##########
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/backend/tec_compiler_utils.h

Review comment:
       rename this file to tec_compiler_utils.h?

##########
File path: src/relay/backend/te_compiler.h
##########
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/backend/te_compiler.h
+ * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns.
+ *
+ *
+ * This represents the new design of the Relay compilation flow and will replace the interface
+ * contained in compile_engine.h as we migrate towards a standard pass-based lowering of
+ * Relay functions.
+ *
+ * This file provides an internal API which lowers Relay programs to components which
+ * can be combined with TVM produced kernels to compile an entire program.
+ *
+ * The result of lowering contains a combination of `runtime::Module`s produced by external
+ * compilers and a set of lowered PrimFns which can be code generated for targets.
+ */
+#ifndef TVM_RELAY_BACKEND_TE_COMPILER_H_
+#define TVM_RELAY_BACKEND_TE_COMPILER_H_
+
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op_strategy.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/module.h>
+#include <tvm/topi/elemwise.h>
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "../transforms/infer_layout_utils.h"
+#include "../transforms/pass_utils.h"
+#include "./te_compiler_cache.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity's sake;
+// we should add a version of context which works in Map
+using TargetsMap = std::unordered_map<int, Target>;
+using DeviceContextMap =
+    std::unordered_map<Expr, tvm::Device, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+using ProcessFn = std::function<void(Function)>;
+
+/*!
+ * \brief A compiler which lowers primitive Relay functions to tensor expressions
+ * and schedules them into TIR functions.
+ */
+class TECompilerNode : public Object {
+ public:
+  /*! \brief destructor */
+  virtual ~TECompilerNode() {}
+  /*!
+   * \brief Get lowered result.
+   * \param key The key to the cached function.
+   * \return The result.
+   */
+  virtual CachedFunc Lower(const CCacheKey& key) = 0;
+
+  virtual Map<String, IRModule> GetLoweredFunctions() = 0;

Review comment:
       add doc to this func
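Judging from the `TECompilerImpl::GetLoweredFunctions` implementation in `te_compiler.cc`, the function walks the cache and merges each entry's lowered functions into one `IRModule` per target string. A standalone sketch of that grouping, with a candidate doc comment; `CacheEntry` and `GetLoweredFunctionsSketch` are hypothetical, and strings stand in for targets and PrimFns.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical cache entry: the target a function was lowered for and the
// names of the PrimFns produced for it.
struct CacheEntry {
  std::string target;
  std::vector<std::string> prim_funcs;
};

/*!
 * \brief Collect every lowered function in the cache, grouped by target.
 * \return A map from target string to the PrimFns lowered for that target.
 */
std::map<std::string, std::vector<std::string>> GetLoweredFunctionsSketch(
    const std::vector<CacheEntry>& cache) {
  std::map<std::string, std::vector<std::string>> lowered;
  for (const auto& entry : cache) {
    // operator[] creates the empty per-target bucket on first use, playing the
    // role of the explicit "create an empty IRModule if missing" step.
    auto& bucket = lowered[entry.target];
    bucket.insert(bucket.end(), entry.prim_funcs.begin(), entry.prim_funcs.end());
  }
  return lowered;
}
```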

##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,687 @@
+/*

Review comment:
       also rename this file to te_compiler_utils.cc?

##########
File path: src/relay/backend/te_compiler.h
##########
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relay/backend/te_compiler.h
+ * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns.
+ *
+ *
+ * This represents the new design of the Relay compilation flow and will replace the interface
+ * contained in compile_engine.h as we migrate towards a standard pass-based lowering of
+ * Relay functions.
+ *
+ * This file provides an internal API which lowers Relay programs to components which
+ * can be combined with TVM produced kernels to compile an entire program.
+ *
+ * The result of lowering contains a combination of `runtime::Module`s produced by external
+ * compilers and a set of lowered PrimFns which can be code generated for targets.
+ */
+#ifndef TVM_RELAY_BACKEND_TE_COMPILER_H_
+#define TVM_RELAY_BACKEND_TE_COMPILER_H_
+
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op_strategy.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/module.h>
+#include <tvm/topi/elemwise.h>
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "../transforms/infer_layout_utils.h"
+#include "../transforms/pass_utils.h"
+#include "./te_compiler_cache.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+// TODO(@jroesch, @chrisS) these should be a tvm::Map for uniformity's sake;
+// we should add a version of context which works in Map
+using TargetsMap = std::unordered_map<int, Target>;
+using DeviceContextMap =
+    std::unordered_map<Expr, tvm::Device, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+using ProcessFn = std::function<void(Function)>;
+
+/*!
+ * \brief A compiler which lowers primitive Relay functions to tensor expressions
+ * and schedules them into TIR functions.
+ */
+class TECompilerNode : public Object {
+ public:
+  /*! \brief destructor */
+  virtual ~TECompilerNode() {}
+  /*!
+   * \brief Get lowered result.
+   * \param key The key to the cached function.
+   * \return The result.
+   */
+  virtual CachedFunc Lower(const CCacheKey& key) = 0;
+
+  virtual Map<String, IRModule> GetLoweredFunctions() = 0;
+  /*!
+   * \brief Just in time compile to get a PackedFunc.
+   * \param key The key to the cached function.
+   * \return The result.
+   */
+  virtual PackedFunc JIT(const CCacheKey& key) = 0;
+  /*!
+   * \brief Lower the shape function.
+   * \param key The key to the cached function.
+   * \return The result.
+   */
+  virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0;
+  /*!
+   * \brief Lower the external function using external codegen tools.
+   * \return The runtime modules for each needed external codegen tool.
+   */
+  virtual tvm::Array<tvm::runtime::Module> LowerExternalFunctions() = 0;
+
+  /*! \brief clear the cache. */
+  virtual void Clear() = 0;
+
+  // VisitAttrs

Review comment:
       remove this line




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r664942122



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -0,0 +1,760 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "te_compiler.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "te_compiler.h"
+#include "te_compiler_cache.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+// TODO(@jroesch, @csullivan): declare directly elsewhere
+backend::StaticMemoryPlan GraphPlanMemory(const Function& func);
+
+namespace tec {
+
+using namespace tvm::relay::transform;
+
+TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
+
+class TECompilerImpl : public TECompilerNode {
+ public:
+  // Lower the function.
+  CachedFunc Lower(const CCacheKey& key, std::function<String(String)> mangle_fn) {
+    return LowerInternal(key, mangle_fn)->cached_func;
+  }
+
+  CachedFunc Lower(const CCacheKey& key, const String mod_name) {
+    auto mangle_fn = [mod_name](String name) {
+      std::cout << "inner mod name" << mod_name << std::endl;
+      return runtime::get_name_mangled(mod_name, name);
+    };
+
+    return Lower(key, mangle_fn);
+  }
+
+  // For now, build one module per function.
+  PackedFunc JIT(const CCacheKey& key) final {
+    auto mangle_fn = [](String name) { return name; };
+    CCacheValue value = LowerInternal(key, mangle_fn);
+    if (value->packed_func != nullptr) {
+      return value->packed_func;
+    }
+    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));
+    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
+    return value->packed_func;
+  }
+
+  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
+    return LowerShapeFuncInternal(key)->cached_func;
+  }
+
+  Map<String, IRModule> GetLoweredFunctions() {
+    Map<String, IRModule> lowered_functions;
+    for (const auto& it : cache_) {
+      auto source_func = it.first;
+      auto lowered_func = it.second;
+      auto target = source_func->target;
+
+      if (!lowered_functions.count(target->str())) {
+        lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
+      }
+
+      lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
+    }
+    return lowered_functions;
+  }
+
+  Array<tvm::runtime::Module> LowerExternalFunctions() {
+    Array<tvm::runtime::Module> ret;
+    std::unordered_map<std::string, std::string> cached_symbol;
+    std::vector<CCacheKey> cached_ext_funcs;
+    for (const auto& it : cache_) {
+      auto src_func = it.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
+        std::string code_gen_name = code_gen.value();
+        cached_ext_funcs.push_back(it.first);
+
+        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                      << AsText(src_func, false);
+
+        std::string sn = symbol_name.value();
+        if (cached_symbol.count(sn)) {
+          cached_symbol[sn] = code_gen_name;
+        } else {
+          ICHECK_NE(sn, code_gen_name)
+              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
+        }
+
+        std::string ext_name = "relay.ext." + code_gen_name;
+        auto pf = tvm::runtime::Registry::Get(ext_name);
+        ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
+        // No need to keep compiler attribute at this point, functions have been
+        // extracted for specific codegen.
+        src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        runtime::Module ext_mod = (*pf)(src_func);
+
+        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
+        ret.push_back(ext_mod);
+      }
+    }
+
+    // No need to cache external functions as we collected them all to create
+    // external runtime modules.
+    for (const auto& it : cached_ext_funcs) {
+      cache_.erase(it);
+    }
+    return ret;
+  }
+
+  void Clear() final { cache_.clear(); }
+
+  // List all items in the cache.
+  Array<ObjectRef> ListItems() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    Array<ObjectRef> items;
+    for (auto& kv : cache_) {
+      items.push_back(kv.first);
+      items.push_back(kv.second);
+    }
+    return items;
+  }
+
+  /*!
+   * \brief Get the cache key of the function that is being lowered currently
+   * \return the cache key
+   */
+  CCacheKey GetCurrentCCacheKey() { return cur_ccache_key_; }
+
+ private:
+  // implement lowered func
+  CCacheValue LowerInternal(const CCacheKey& key, std::function<String(String)> mangle_fn) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    CCacheValue value;
+    auto it = cache_.find(key);
+    if (it != cache_.end()) {
+      it->second->use_count += 1;
+      if (it->second->cached_func.defined()) return it->second;
+      value = it->second;
+    } else {
+      value = CCacheValue(make_object<CCacheValueNode>());
+      value->use_count = 1;
+      cache_[key] = value;
+    }
+    cur_ccache_key_ = key;
+
+    // No need to lower external functions for now. We will invoke the external
+    // codegen tool once and lower all functions together.
+    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
+      auto ir_module = IRModule();
+      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
+      auto func_name = GetUniqueName(name_node.value(), &name_map_);
+      auto target = Target("ext_dev");
+      auto global_var = GlobalVar(func_name);
+      global_var->checked_type_ = key->source_func->checked_type();
+      value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
+      return value;
+    }
+
+    // Enforce use of the target.
+    With<Target> target_scope(key->target);
+
+    ICHECK(!value->cached_func.defined());
+    auto cfunc = PrimFuncFor(key->source_func, key->target, [&](std::string name) {
+      auto mangled = mangle_fn(name);
+      std::cout << "Mangled: " << mangled << std::endl;

Review comment:
       Remove this.
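For readers following `LowerInternal`: it takes the mutex, bumps `use_count` on a cache hit, returns early when `cached_func` is already defined, and otherwise lowers the function once and stores the result. A minimal standalone sketch of that memoization pattern; `CacheValue` and `LoweringCacheSketch` are hypothetical names, and a string stands in for the `CachedFunc`.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>

// Hypothetical analogue of CCacheValueNode; `defined` plays the role of
// cached_func.defined().
struct CacheValue {
  int use_count = 0;
  bool defined = false;
  std::string lowered;
};

class LoweringCacheSketch {
 public:
  explicit LoweringCacheSketch(std::function<std::string(const std::string&)> lower_fn)
      : lower_fn_(std::move(lower_fn)) {}

  // Return the cached result for `key`, lowering it only on the first request.
  std::shared_ptr<CacheValue> Lower(const std::string& key) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::shared_ptr<CacheValue> value;
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      it->second->use_count += 1;
      if (it->second->defined) return it->second;  // hit: reuse the earlier lowering
      value = it->second;
    } else {
      value = std::make_shared<CacheValue>();
      value->use_count = 1;
      cache_[key] = value;
    }
    value->lowered = lower_fn_(key);  // the expensive step
    value->defined = true;
    return value;
  }

 private:
  std::function<std::string(const std::string&)> lower_fn_;
  std::map<std::string, std::shared_ptr<CacheValue>> cache_;
  std::mutex mutex_;  // held for the whole lookup-and-lower, as in the quoted code
};
```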







[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r609055222



##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -349,65 +380,47 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
     return AddNode(node, GetRef<Expr>(op));
   }
 
-  std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
-    Expr expr = GetRef<Expr>(op);
-    Function func;
-    if (op->op.as<OpNode>()) {
-      LOG(FATAL) << "Operators should be transformed away; try applying"
-                 << "the fuse_ops transformation to the expression.";
-    } else if (op->op.as<GlobalVarNode>()) {
-      LOG(FATAL) << "Not implemented";
-    } else if (op->op.as<FunctionNode>()) {
-      func = GetRef<Function>(op->op.as<FunctionNode>());
-    } else {
-      LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
-    }
-    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
-      LOG(FATAL) << "TVM only support calls to primitive functions "
-                 << "(i.e functions composed of fusable operator invocations)";
-    }
+  std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
+    relay::Call call = GetRef<Call>(call_node);
+    if (auto global_node = call->op.as<GlobalVarNode>()) {
 
-    auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
-    auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
-    Target target;
-    // Handle external function
-    if (func->GetAttr<String>(attr::kCompiler).defined()) {
-      target = Target("ext_dev");
-      CCacheKey key = (*pf0)(func, target);
-      CachedFunc ext_func = (*pf1)(compile_engine_, key);
-      ICHECK(ext_func.defined()) << "External function is not defined.";
-      UpdateConstants(func, &params_);
-      return GraphAddCallNode(op, ext_func->func_name, ext_func->func_name);
-    }
+      auto prim_fn_name = global_node->name_hint;
 
-    ICHECK_GE(storage_device_map_.count(expr), 0);
-    auto& device_type = storage_device_map_[expr][1];
-    auto call_dev_type = device_type[0]->value;
-    // Normal Relay Function
-    if (targets_.size() == 1) {
-      // homogeneous execution.
-      const auto& it = targets_.begin();
-      target = (*it).second;
-    } else {
-      // heterogeneous execution.
-      std::string call_dev_name;
-      if (call_dev_type == 0) {
-        call_dev_name = "llvm";
+      Target target;
+
+      // // Handle external function
+      // if (func->GetAttr<String>(attr::kCompiler).defined()) {
+      //   UpdateConstants(func, &params_);
+      //   return GraphAddCallNode(call_node, prim_fn_name, prim_fn_name);

Review comment:
       This code has been moved into the lowering, so the case below should catch it now.
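The removed code above chose the target per call: the only target in a homogeneous build, otherwise a lookup by device type that fails when no target is registered for that device. A standalone sketch of that rule; `SelectTarget` and `TargetMapSketch` are hypothetical, and strings stand in for `tvm::Target`.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical: targets keyed by integer device type, as in TargetsMap
// (e.g. 1 for CPU, 2 for CUDA in DLPack's numbering), with strings for targets.
using TargetMapSketch = std::map<int, std::string>;

std::string SelectTarget(const TargetMapSketch& targets, int dev_type) {
  if (targets.size() == 1) {
    // Homogeneous execution: the single target is used regardless of dev_type.
    return targets.begin()->second;
  }
  // Heterogeneous execution: the device type must have an explicit target.
  auto it = targets.find(dev_type);
  if (it == targets.end()) {
    throw std::runtime_error("No target is provided for device type " +
                             std::to_string(dev_type));
  }
  return it->second;
}
```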







[GitHub] [tvm] jroesch merged pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch merged pull request #7518:
URL: https://github.com/apache/tvm/pull/7518


   





[GitHub] [tvm] jroesch commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r664943202



##########
File path: tests/python/relay/test_op_fast_math.py
##########
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pdb

Review comment:
       Remove

##########
File path: tests/python/relay/test_op_fast_math.py
##########
@@ -44,7 +45,6 @@ def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
         func_name = "tvmgen_default_fused_" + name
         # When there're multiple targets in tvm.testing.parametrize_targets, the function
         # built will have a "_1" in function name
-        assert func_name in graph

Review comment:
       Restore
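Why the restored assertion still holds when a `_1` suffix appears: `func_name in graph` is a substring test (assuming `graph` is the graph JSON string), and a suffixed name still contains the unsuffixed one. The sketch below assumes names are made unique by appending `_1`, `_2`, ... to repeats, as the test comment describes. `GetUniqueNameSketch` is hypothetical and may differ in detail from the `GetUniqueName` helper called in `te_compiler.cc`.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical unique-name generator: the first request for a name returns it
// unchanged; later requests get "_1", "_2", ... appended.
std::string GetUniqueNameSketch(std::string name, std::unordered_map<std::string, int>* name_map) {
  while (true) {
    auto it = name_map->find(name);
    if (it == name_map->end()) {
      (*name_map)[name] = 1;  // first use of this exact name
      return name;
    }
    // Name taken: try the next numbered variant and bump the counter.
    std::string candidate = name + "_" + std::to_string(it->second);
    ++(it->second);
    name = candidate;
  }
}
```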







[GitHub] [tvm] tkonolige commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r661762009



##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -420,170 +387,57 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
     return fields;
   }
 
-  std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const std::string& op_name,
-                                             const std::string& func_name, GraphAttrs attrs) {
+  bool ShareSameStorage(const Expr& lhs, const Expr& rhs) {
+    StorageInfo lit = GetStorageInfo(lhs);
+    StorageInfo rit = GetStorageInfo(rhs);
+    int64_t lhs_storage_id = lit->storage_ids[0];
+    int64_t rhs_storage_id = rit->storage_ids[0];
+    return lhs_storage_id == rhs_storage_id;
+  }
+
+  std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const std::string& func_name,
+                                             GraphAttrs op_attrs) {
     std::vector<GraphNodeRef> inputs;
     for (auto arg : op->args) {
       auto res = VisitExpr(arg);
       for (auto nr : res) {
         inputs.push_back(nr);
       }
     }
-    auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs);
-    return AddNode(node, GetRef<Expr>(op));
-  }
-
-  bool ShareSameStorage(const Expr& lhs, const Expr& rhs) {
-    auto lit = storage_device_map_.find(lhs);
-    auto rit = storage_device_map_.find(rhs);
-    ICHECK(lit != storage_device_map_.end());
-    ICHECK(rit != storage_device_map_.end());
-    int64_t lhs_storage_id = ((*lit).second)[0][0]->value;
-    int64_t rhs_storage_id = ((*rit).second)[0][0]->value;
-    return lhs_storage_id == rhs_storage_id;
-  }
 
-  /*!
-   * \brief Obtain the Target from the device type.
-   * If homogenous compilation, this will return the only target.
-   * If heteregenous compilation, this will select associated using the targets_ Map.
-   *
-   * \param dev_type
-   * \return Target
-   */
-  Target GetTargetFromInteger(int64_t dev_type) {
-    if (targets_.size() == 1) {
-      // homogeneous execution.
-      const auto& it = targets_.begin();
-      return (*it).second;
-    } else {
-      // heterogeneous execution.
-      std::string call_dev_name;
-      if (dev_type == 0) {
-        call_dev_name = "llvm";
-      } else {
-        call_dev_name = runtime::DeviceName(dev_type);
-      }
-      if (targets_.count(dev_type) == 0) {
-        LOG(FATAL) << "No target is provided for device " << call_dev_name;
-      }
-      return targets_[dev_type];
+    /// An adapted version of the storage optimization for the time being.
+    bool reshape_only = false;
+    if (op->attrs.defined() && op->attrs.as<TIRCallAttrs>()) {
+      reshape_only = true;
     }
-  }
 
-  /*!
-   * \brief Update the function metadata for a given cached function and its relay
-   * primitive function.
-   *
-   * \param cfunc The cached function as provided the by the compile engine
-   * \param relay_func The source relay primitive function
-   * \param relay_target The target associated with relay primitive function
-   */
-  void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& relay_func,
-                              const Target& relay_target) {
-    auto fi_node = make_object<FunctionInfoNode>();
-    for (const auto& kv : cfunc->funcs->functions) {
-      auto primfunc = Downcast<tir::PrimFunc>(kv.second);
-      auto workspace_byte_alignment = relay_target->GetAttr<Integer>("workspace-byte-alignment")
-                                          .value_or(tvm::runtime::kDefaultWorkspaceAlignment);
-      Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment);
-      Target primfunc_target = relay_target;
-      if (primfunc->attrs->dict.count("target")) {
-        primfunc_target = Downcast<Target>(primfunc->attrs->dict["target"]);
-      }
-      fi_node->workspace_sizes.Set(primfunc_target, workspace_size);
-      // Calculating size for I/O
-      for (auto const& param : primfunc->params) {
-        auto p_shape = primfunc->buffer_map[param]->shape;
-        int num_of_elements = 1;
-        for (const auto& dim_index_expr : p_shape) {
-          if (dim_index_expr->IsInstance<IntImmNode>()) {
-            num_of_elements *= dim_index_expr.as<IntImmNode>()->value;
-          } else {
-            // If shape is dynamic, we cannot calculate workspace in compile time.
-            num_of_elements = 0;
-          }
-        }
-        int element_size = primfunc->buffer_map[param]->dtype.bytes();
-        fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements);
-      }
-      fi_node->constant_sizes.Set(primfunc_target, 0);
-      fi_node->tir_primfuncs.Set(primfunc_target, primfunc);
-      fi_node->relay_primfuncs.Set(primfunc_target, relay_func);
-    }
-    function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node));
-  }
-
-  std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
-    Expr expr = GetRef<Expr>(op);
-    Function func;
-    if (op->op.as<OpNode>()) {
-      LOG(FATAL) << "Operators should be transformed away; try applying"
-                 << "the fuse_ops transformation to the expression.";
-    } else if (op->op.as<GlobalVarNode>()) {
-      LOG(FATAL) << "Not implemented";
-    } else if (op->op.as<FunctionNode>()) {
-      func = GetRef<Function>(op->op.as<FunctionNode>());
-    } else {
-      LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
-    }
-    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
-      LOG(FATAL) << "TVM only support calls to primitive functions "
-                 << "(i.e functions composed of fusable operator invocations)";
-    }
-
-    // Copy attrs from function into the graph node
-    // For now we only handle strings
-    GraphAttrs attrs;
-    for (auto p : func->attrs->dict) {
-      if (p.second.as<StringObj>()) {
-        attrs[p.first] = std::string(Downcast<String>(p.second));
-      }

Review comment:
       You don't want to remove this.
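The review comment appears to be attached to the end of the removed block that copies the primitive function's attributes onto the graph node. As that code's own comment says, only string-valued attributes are copied for now. Below is a minimal, self-contained sketch of that filtering. It assumes a simplified attribute model. `AttrValue`, `GraphAttrsSketch`, and `CopyStringAttrs` are hypothetical names, not TVM APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <variant>

// Hypothetical stand-in for a Relay function attribute value.
// Only strings and integers are modeled here.
using AttrValue = std::variant<std::string, int64_t>;

// Hypothetical stand-in for GraphAttrs: string-keyed graph node attributes.
using GraphAttrsSketch = std::map<std::string, std::string>;

// Copy only string-valued attributes from a function's attribute dictionary
// into the graph node's attributes. Values of any other type are skipped,
// mirroring "For now we only handle strings" in the removed code.
GraphAttrsSketch CopyStringAttrs(const std::map<std::string, AttrValue>& func_attrs) {
  GraphAttrsSketch attrs;
  for (const auto& [key, value] : func_attrs) {
    if (const auto* s = std::get_if<std::string>(&value)) {
      attrs[key] = *s;
    }
  }
  return attrs;
}
```

For example, a function carrying a string `hash` attribute and an integer `Primitive` flag would yield a graph node with only the `hash` attribute.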

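The removed `GetTargetFromInteger` selects a target for a call in two ways. With a single target (homogeneous compilation), that target is returned regardless of the device type. With several targets (heterogeneous compilation), the device type must have an entry in the map, or compilation fails. The sketch below shows that logic with standard containers. It assumes a target can be represented by its string. `TargetSketch` and `GetTargetFromIntegerSketch` are hypothetical, an exception stands in for `LOG(FATAL)`, and a plain numeric label replaces `runtime::DeviceName`.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for tvm::Target: just the target string.
using TargetSketch = std::string;

// Mirrors the selection logic of the removed GetTargetFromInteger.
TargetSketch GetTargetFromIntegerSketch(const std::map<int64_t, TargetSketch>& targets,
                                        int64_t dev_type) {
  if (targets.size() == 1) {
    // Homogeneous execution: the only target is used for every device type.
    return targets.begin()->second;
  }
  // Heterogeneous execution: dev_type must map to a target.
  auto it = targets.find(dev_type);
  if (it == targets.end()) {
    // As in the original error message, device type 0 is reported as "llvm".
    // Other types get a numeric label here (runtime::DeviceName is not modeled).
    std::string name = dev_type == 0 ? "llvm" : "device " + std::to_string(dev_type);
    throw std::runtime_error("No target is provided for device " + name);
  }
  return it->second;
}
```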

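The removed `UpdateFunctionMetadata` computes each parameter's I/O size as the element size in bytes multiplied by the product of its static dimensions. If any dimension is dynamic, it falls back to 0. The sketch below reproduces that arithmetic under a simplified shape model. `ShapeSketch`, `BufferSketch`, `BufferBytes`, and `TotalIoBytes` are hypothetical. In the removed code, `io_sizes.Set` is called once per parameter with the same target key, so each call appears to overwrite the previous value. `TotalIoBytes` shows summing as one possible alternative. It is not what the PR does.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical shape model: a static dimension holds its extent,
// a dynamic (symbolic) dimension is std::nullopt.
using ShapeSketch = std::vector<std::optional<int64_t>>;

struct BufferSketch {
  ShapeSketch shape;
  int64_t element_bytes;  // e.g. 4 for float32, like DataType::bytes()
};

// Byte size of one buffer: element size times number of elements.
// Any dynamic dimension makes the size 0, as in the removed code, because it
// cannot be computed at compile time. An int64_t accumulator avoids the
// overflow an `int` could hit for large buffers.
int64_t BufferBytes(const BufferSketch& buf) {
  int64_t num_elements = 1;
  for (const auto& dim : buf.shape) {
    if (!dim.has_value()) {
      return 0;
    }
    num_elements *= *dim;
  }
  return num_elements * buf.element_bytes;
}

// Hypothetical aggregation: sum the sizes of all parameters of a function.
int64_t TotalIoBytes(const std::vector<BufferSketch>& params) {
  int64_t total = 0;
  for (const auto& p : params) {
    total += BufferBytes(p);
  }
  return total;
}
```

For a float32 buffer of shape [1, 3, 224, 224], this gives 150528 elements × 4 bytes = 602112 bytes.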

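`ShareSameStorage`, both the removed and the added version, treats two expressions as sharing storage when the first storage id assigned to each by memory planning is the same. Going by the "adapted version of the storage optimization" comment next to `reshape_only`, the check is presumably used to let a reshape-only call whose output reuses its input's buffer be skipped. That caller is not shown in this hunk. The sketch below is hypothetical throughout. It assumes expressions are identified by name, and it uses `assert` in place of `ICHECK`, which only checks in debug builds.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for the per-expression storage info produced by
// memory planning: one storage id per output tensor.
struct StorageInfoSketch {
  std::vector<int64_t> storage_ids;
};

// Expressions are identified by name here instead of by Expr handle.
using StorageMapSketch = std::unordered_map<std::string, StorageInfoSketch>;

// Two expressions share storage if their first storage ids match.
// Both expressions must already have been planned.
bool ShareSameStorageSketch(const StorageMapSketch& storage_map, const std::string& lhs,
                            const std::string& rhs) {
  auto lit = storage_map.find(lhs);
  auto rit = storage_map.find(rhs);
  assert(lit != storage_map.end());  // stands in for ICHECK
  assert(rit != storage_map.end());
  assert(!lit->second.storage_ids.empty() && !rit->second.storage_ids.empty());
  return lit->second.storage_ids[0] == rit->second.storage_ids[0];
}
```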

[GitHub] [tvm] jroesch edited a comment on pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

Posted by GitBox <gi...@apache.org>.
jroesch edited a comment on pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#issuecomment-805600377


   cc @mbaret 

