You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/06/08 18:00:41 UTC

[GitHub] [tvm] mbs-octoml commented on a diff in pull request #11619: [BYOC] RelayToTIR custom codegen passes can still depend on dynamic shape functions

mbs-octoml commented on code in PR #11619:
URL: https://github.com/apache/tvm/pull/11619#discussion_r892703621


##########
src/relay/backend/te_compiler.cc:
##########
@@ -566,100 +566,128 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
         return itr->second;
       }
     } else if (const auto* function_node = expr.as<FunctionNode>()) {
-      if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
-        // Not marked as primitive by FuseOps.
-        return {};
-      }
-      if (const auto* call_node = function_node->body.as<CallNode>()) {
-        if (call_node->op == debug_op_) {
-          // Debug 'primitives' are not lowered.
-          return {};
+      if (function_node->HasNonzeroAttr(attr::kExtern)) {
+        // We have a regular call to an 'extern' function. The call itself needs to be rewritten
+        // to call_lowered form, and any required dynamic shape functions generated and
+        // cross-linked.
+        return GetRef<Function>(function_node);
+      } else if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+        if (const auto* call_node = function_node->body.as<CallNode>()) {
+          if (call_node->op == debug_op_) {
+            // Debug 'primitives' are not lowered.
+            return {};
+          }
         }
+        // We have a regular call to a 'primitive' function (possibly with a 'Compiler' attribute).
+        // We need to lower and rewrite the call.
+        return GetRef<Function>(function_node);
+      } else {
+        // Not marked as primitive during partitioning or TVM fusion.
+        return {};
       }
-      return GetRef<Function>(function_node);
     } else {
       return {};
     }
   }
 
   /*!
-   * \brief Lowers the primitive function \p func to TIR for ultimate execution
-   * on a device with configuration \p target. Returns the global var bound
-   * to the TIR implementation, and attributes to attach to the call to identify it as
-   * a TIR call.
+   * \brief Returns a 'call_lowered' call to \p prim_fn_var with \p args and \p span with all the
+   * required attributes filled in. Generally \p prim_fn_var will correspond to the lowered or
+   * externally codegen-ed form of \p original_function, where \p lowered_functions binds all
+   * the required lowered functions.
+   *
+   * The call's attributes will capture:
+   *  - Any attributes on the original_function.
+   *  - All the lowered functions.
+   *    TODO(mbs): Pretty sure that's no longer needed.
+   *  - Details needed to cross-link the call to its dynamic shape function, if any.
    */
-  Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Span span, Target target) {
-    CCacheKey key = CCacheKey(func, target);
-    CachedFunc cfunc = compiler_->Lower(key, module_name_);
-    ICHECK(cfunc.defined());
-
-    auto opt_compiler = func->GetAttr<String>(attr::kCompiler);
+  Expr MakeLoweredCall(const BaseFunc& original_function, const GlobalVar& prim_fn_var,
+                       Array<Expr> args, Span span, const Target& target,
+                       const Map<GlobalVar, BaseFunc>& lowered_functions) {
+    auto opt_compiler = original_function->GetAttr<String>(attr::kCompiler);
 
     // Add some metadata on top of the *original function* and invoke the callback so it can
     // be captured.
     // TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
     Map<GlobalVar, tir::PrimFunc> prim_fns;
     Array<GlobalVar> all_prim_fn_vars;
-    for (const auto& kv : cfunc->funcs->functions) {
+    for (const auto& kv : lowered_functions) {
       if (opt_compiler) {
-        // We expect just the original func but with just the ExternalSymbol attribute signaling
-        // the function (will be) compiled externally.
+        // We expect the original function to have just the "Extern" attribute signaling the
+        // function (will be) compiled externally.
         ICHECK(kv.second.as<FunctionNode>())
             << PrettyPrint(kv.first) << " must be bound to an (external) Function";
       } else {
-        // We expect one or more PrimFuncs, one of which corresponds to 'the' lowered primitive
-        // (and the rest in support of that via tir::Calls).
+        // We expect one or more PrimFuncs, one of which corresponds to 'the' lowered primitive,
+        // and the rest are in support of that via tir::Calls.
         ICHECK(kv.second.as<tir::PrimFuncNode>())
             << PrettyPrint(kv.first) << " must be bound to a PrimFunc";
         prim_fns.Set(kv.first, Downcast<tir::PrimFunc>(kv.second));
         all_prim_fn_vars.push_back(kv.first);
       }
     }
-    Function func_with_metadata = func;
-    func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", cfunc->prim_fn_var);
-    func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns);
-    func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, cfunc->target);
-    this->process_fn_(func_with_metadata);
 
+    // Alas, WithAttr cannot work with base classes.

Review Comment:
   Yeah, I was disappointed by this too.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org