Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/19 19:05:12 UTC

[GitHub] [tvm] mbs-octoml commented on a change in pull request #9312: [DRAFT] Change Call with TIRCallAttrs to call_lowered op

mbs-octoml commented on a change in pull request #9312:
URL: https://github.com/apache/tvm/pull/9312#discussion_r732162456



##########
File path: src/relay/backend/graph_plan_memory.cc
##########
@@ -147,6 +147,12 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor {
    */
   const std::vector<StorageToken*>& GetToken(const Expr& expr) {
     this->VisitExpr(expr);
+    // Return empty if called on a Function
+    // OK actually looks like we do want to do stuff for function nodes?
+    if (expr->checked_type().as<FuncTypeNode>()) {
+      static const std::vector<StorageToken*> empty;
+      return empty;

Review comment:
       nit: just return {} should work.
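A caveat on this nit: GetToken is declared to return `const std::vector<StorageToken*>&`. With that signature, a braced `return {};` binds the returned reference to a temporary vector that is destroyed when the return statement completes, so the caller is left holding a dangling reference. `return {};` is only safe if the return type becomes a by-value `std::vector<StorageToken*>`. The standalone sketch below shows both safe variants. `StorageToken` is an empty stand-in struct, and `GetTokensByRef` and `GetTokensByValue` are hypothetical names, not TVM code.

```cpp
#include <vector>

// Empty stand-in for TVM's StorageToken; only pointers to it are stored.
struct StorageToken {};

// Variant 1: keep the const-reference return type. The empty result must then
// outlive the call, which a function-local static guarantees (it is built once
// and lives until program exit).
const std::vector<StorageToken*>& GetTokensByRef(bool is_function,
                                                 const std::vector<StorageToken*>& tokens) {
  if (is_function) {
    static const std::vector<StorageToken*> empty;
    return empty;
  }
  return tokens;  // refers to the caller-owned vector, no copy
}

// Variant 2: return by value. Here `return {};` is fine because the caller
// receives its own (empty) vector rather than a reference to a temporary.
std::vector<StorageToken*> GetTokensByValue(bool is_function,
                                            const std::vector<StorageToken*>& tokens) {
  if (is_function) {
    return {};
  }
  return tokens;  // copies the token list
}
```

The by-value form is simpler to read but copies the token list on every non-empty call; the static keeps the existing no-copy signature.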

##########
File path: src/relay/op/vm/vm.cc
##########
@@ -195,6 +195,46 @@ RELAY_REGISTER_OP("vm.invoke_tvm_op")
                              return {topi::identity(inputs[0])};
                            });
 
+// call_lowered
+bool CallTIRRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  // Types = [func, args, ret_type]
+  ICHECK_EQ(types.size(), 3u);
+  auto func_type = types[0].as<FuncTypeNode>();
+  ICHECK(func_type != nullptr) << "input must be operator with known type";
+  auto input_type = types[1].as<TupleTypeNode>();
+  ICHECK(input_type != nullptr)
+      << "internal invariant violated: call_lowered inputs must be a tuple";
+  reporter->Assign(types[2], func_type->ret_type);
+  return true;
+}
+
+Expr CallTIR(Expr func, Expr inputs, Attrs attrs) {
+  ICHECK(func.as<GlobalVarNode>()) << "Function to call should be GlobalVarNode, but got " << func->GetTypeKey();
+  return Call(Op::Get("call_lowered"), {func, inputs}, attrs);
+}
+
+TVM_REGISTER_GLOBAL("relay.op.call_lowered").set_body_typed(CallTIR);
+
+RELAY_REGISTER_OP("call_lowered")
+    .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("op", "Function", "The operation to call")
+    .add_argument("ins", "Tuple", "The input tensors.")
+    .add_type_rel("CallTIRRel", CallTIRRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)

Review comment:
       As we discussed, all the 'pretend this is actually a lowerable primitive' stuff can go.
   I'd add a set_attrs_type_key annotation connecting this op to TIRCallAttrs so that parsing will work.
   Those attrs should probably also be renamed to match the operator name.
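To illustrate why the attrs type annotation matters for parsing, here is a toy model, assuming (as the comment implies) that the text parser relies on the op's registered attrs type key to decide which attrs node to build from the `key=value` pairs it reads. None of this is TVM code: `OpInfo`, `RegisterOp`, `ParseAttrs`, and the key string `example.TIRCallAttrs` are hypothetical.

```cpp
#include <map>
#include <optional>
#include <stdexcept>
#include <string>

// Toy parsed attributes: the attrs type key plus the raw key=value text.
struct ParsedAttrs {
  std::string type_key;
  std::map<std::string, std::string> fields;
};

// Hypothetical per-op metadata; attrs_type_key stays empty unless annotated.
struct OpInfo {
  std::string attrs_type_key;
};

std::map<std::string, OpInfo>& OpTable() {
  static std::map<std::string, OpInfo> table;
  return table;
}

void RegisterOp(const std::string& name, const std::string& attrs_type_key = "") {
  OpTable()[name] = OpInfo{attrs_type_key};
}

// Mimics a parser reading `op(args, key=value, ...)`: the op's registered attrs
// type key tells it which attrs node to build from the key=value pairs.
std::optional<ParsedAttrs> ParseAttrs(const std::string& op_name,
                                      const std::map<std::string, std::string>& kv) {
  auto it = OpTable().find(op_name);
  if (it == OpTable().end()) {
    throw std::runtime_error("unknown op: " + op_name);
  }
  if (kv.empty()) {
    return std::nullopt;  // no attributes written in the text
  }
  if (it->second.attrs_type_key.empty()) {
    // Attributes were written, but the op never declared its attrs type.
    throw std::runtime_error("op " + op_name + " has no attrs type key");
  }
  return ParsedAttrs{it->second.attrs_type_key, kv};
}
```

An unannotated op can still be parsed when no attributes are written, which is why this kind of gap tends to show up only once printed IR with attributes is fed back through the parser.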

##########
File path: src/relay/op/vm/vm.cc
##########
@@ -195,6 +195,46 @@ RELAY_REGISTER_OP("vm.invoke_tvm_op")
                              return {topi::identity(inputs[0])};
                            });
 
+// call_lowered
+bool CallTIRRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  // Types = [func, args, ret_type]
+  ICHECK_EQ(types.size(), 3u);
+  auto func_type = types[0].as<FuncTypeNode>();
+  ICHECK(func_type != nullptr) << "input must be operator with known type";
+  auto input_type = types[1].as<TupleTypeNode>();
+  ICHECK(input_type != nullptr)
+      << "internal invariant violated: call_lowered inputs must be a tuple";
+  reporter->Assign(types[2], func_type->ret_type);
+  return true;
+}
+
+Expr CallTIR(Expr func, Expr inputs, Attrs attrs) {

Review comment:
       I'd suggest putting the op registration & helpers in relay/op/calls.{h,cc}.
   You'll probably want a CallLoweredOp() helper in there that caches the op in a function-local static, to avoid all the dynamic Op::Get lookups.
   I haven't yet looked at the changes that support pulling these calls apart, but a helper to deconstruct them may also be worthwhile. For "on_device" at least, I found such a helper cut down the tedious code duplication; see relay/op/memory/on_device.{h,cc}.
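The suggested pattern (a cached CallLoweredOp() accessor plus matching construct/deconstruct helpers) can be sketched in plain C++. This is a toy model, not TVM code: `Op`, `Call`, `LookupOp`, `MakeCallLowered`, `CallLoweredProps`, and `GetCallLoweredProps` are hypothetical stand-ins, with `LookupOp` playing the role of the dynamic Op::Get lookup.

```cpp
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins for Relay IR nodes; these are not TVM's classes.
struct Op {
  std::string name;
};

// Stand-in for a dynamic registry lookup such as Op::Get("..."). The counter
// makes the number of lookups observable.
int g_registry_lookups = 0;
const Op* LookupOp(const std::string& name) {
  static std::map<std::string, Op> registry;
  ++g_registry_lookups;
  auto it = registry.find(name);
  if (it == registry.end()) {
    // Pointers to std::map elements stay valid across later insertions.
    it = registry.emplace(name, Op{name}).first;
  }
  return &it->second;
}

// Cached accessor: the lookup runs once, on first use; function-local static
// initialization is thread-safe since C++11.
const Op* CallLoweredOp() {
  static const Op* op = LookupOp("call_lowered");
  return op;
}

struct Call {
  const Op* op;
  std::string func;               // name of the lowered global function
  std::vector<std::string> args;  // the call's inputs
};

// Construction helper: every call_lowered is built the same way.
Call MakeCallLowered(std::string func, std::vector<std::string> args) {
  return Call{CallLoweredOp(), std::move(func), std::move(args)};
}

// Deconstruction helper: yields the pieces only if `call` is a call_lowered,
// so each pass doesn't repeat the check-and-unpack code.
struct CallLoweredProps {
  std::string func;
  std::vector<std::string> args;
};

std::optional<CallLoweredProps> GetCallLoweredProps(const Call& call) {
  if (call.op != CallLoweredOp()) {
    return std::nullopt;
  }
  return CallLoweredProps{call.func, call.args};
}
```

With the op cached, recognizing a call_lowered is a pointer comparison rather than a registry lookup, and keeping both helpers in one place keeps construction and deconstruction in sync.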




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
