You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/09 19:00:54 UTC

[GitHub] [incubator-tvm] icemelon9 commented on a change in pull request #5958: [REFACTOR][RELAY] Move invoke_tvm_op and shape_func to vm dialect

icemelon9 commented on a change in pull request #5958:
URL: https://github.com/apache/incubator-tvm/pull/5958#discussion_r452427478



##########
File path: src/relay/op/vm/vm.cc
##########
@@ -54,5 +58,128 @@ TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed([](Expr expr) {
   return Call(op, {expr}, Attrs(attrs), {});
 });
 
+TVM_REGISTER_GLOBAL("relay.op.vm.shape_func")
+    .set_body_typed([](Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input) {
+      static const Op& op = Op::Get("vm.shape_func");
+      auto attrs = make_object<ShapeFuncAttrs>();
+      attrs->is_input = is_input;
+      return Call(op, {func, inputs, outputs}, Attrs(attrs), {});
+    });
+
+bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4u);
+  auto shape_func_attrs = attrs.as<ShapeFuncAttrs>();
+  CHECK(shape_func_attrs != nullptr) << "Internal compiler error";
+
+  auto func_type = types[0].as<FuncTypeNode>();
+  CHECK(func_type != nullptr);
+
+  auto tuple = TupleType(func_type->arg_types);
+  auto in_types = FlattenTupleType(tuple);
+  auto out_types = FlattenTupleType(func_type->ret_type);
+  Array<Integer> is_input;
+  for (size_t i = 0; i < func_type->arg_types.size(); ++i) {
+    auto const& aty = func_type->arg_types[i];
+    size_t num_types = 1;
+    if (aty.as<TupleTypeNode>()) {
+      num_types = FlattenTupleType(aty).size();
+    }
+    for (size_t j = 0; j < num_types; ++j) {
+      is_input.push_back(shape_func_attrs->is_input[i]);
+    }
+  }
+
+  Array<Type> shape_func_ins, shape_func_outs;
+  for (size_t i = 0; i < in_types.size(); i++) {
+    auto in_type = in_types[i];
+
+    if (is_input[i]) {
+      shape_func_ins.push_back(in_type);
+    } else {
+      auto shape = RankShape(in_type->shape);
+      shape_func_ins.push_back(TensorType(shape, DataType::Int(64)));
+    }
+  }
+
+  for (auto out_type : out_types) {
+    auto rank_shape = RankShape(out_type->shape);
+    shape_func_outs.push_back(TensorType(rank_shape, DataType::Int(64)));
+  }
+
+  auto input_type = TupleType(shape_func_ins);
+  auto output_type = TupleType(shape_func_outs);
+
+  reporter->Assign(types[1], input_type);
+  reporter->Assign(types[2], output_type);
+  reporter->Assign(types[3], TupleType::Empty());
+
+  return true;
+}
+
+RELAY_REGISTER_OP("vm.shape_func")
+    .describe(R"code(Get the shape of a tensor.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .add_argument("tensor", "Tensor", "The tensor to retrieve the shape for.")
+    .add_type_rel("ShapeFuncRel", ShapeFuncRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+    .set_attr<FTVMCompute>("FTVMCompute",
+                           [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+                              const Type& out_dtype) -> Array<te::Tensor> {
+                             return {topi::identity(inputs[0])};
+                           });
+
+bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4u);
+  auto func_type = types[0].as<FuncTypeNode>();
+  CHECK(func_type != nullptr) << "input must be operator with known type";
+  auto input_type = types[1].as<TupleTypeNode>();
+  auto output_type = types[2].as<TupleTypeNode>();
+  CHECK(input_type != nullptr)
+      << "internal invariant violated: invoke_tvm_op inputs must be a tuple";
+  CHECK(output_type != nullptr)
+      << "internal invariant violated: invoke_tvm_op outputs must be a tuple";
+  Type ex_output;
+  if (func_type->ret_type.as<TensorTypeNode>()) {
+    ex_output = TupleType({func_type->ret_type});
+  } else {
+    CHECK(func_type->ret_type.as<TupleTypeNode>()) << "should be tuple type";
+    ex_output = func_type->ret_type;
+  }
+  auto ex_input = TupleType(func_type->arg_types);
+  reporter->Assign(ex_input, GetRef<Type>(input_type));
+  reporter->Assign(ex_output, GetRef<Type>(output_type));
+  reporter->Assign(types[3], TupleType::Empty());
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("relay.op.vm.invoke_tvm_op")
+    .set_body_typed([](Expr func, Expr inputs, Expr outputs) {
+      return Call(Op::Get("vm.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
+    });
+
+RELAY_REGISTER_OP("vm.invoke_tvm_op")
+    .describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .add_argument("op", "Function", "The operation to call")
+    .add_argument("ins", "Tuple", "The input tensors.")
+    .add_argument("outs", "Tuple", "The output tensors.")
+    .add_type_rel("InvokeTVMOP", InvokeTVMOPRel)

Review comment:
       ```suggestion
       .add_type_rel("InvokeTVMOp", InvokeTVMOpRel)
   ```

##########
File path: src/relay/op/vm/vm.cc
##########
@@ -54,5 +58,128 @@ TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed([](Expr expr) {
   return Call(op, {expr}, Attrs(attrs), {});
 });
 
+TVM_REGISTER_GLOBAL("relay.op.vm.shape_func")
+    .set_body_typed([](Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input) {
+      static const Op& op = Op::Get("vm.shape_func");
+      auto attrs = make_object<ShapeFuncAttrs>();
+      attrs->is_input = is_input;
+      return Call(op, {func, inputs, outputs}, Attrs(attrs), {});
+    });
+
+bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4u);
+  auto shape_func_attrs = attrs.as<ShapeFuncAttrs>();
+  CHECK(shape_func_attrs != nullptr) << "Internal compiler error";
+
+  auto func_type = types[0].as<FuncTypeNode>();
+  CHECK(func_type != nullptr);
+
+  auto tuple = TupleType(func_type->arg_types);
+  auto in_types = FlattenTupleType(tuple);
+  auto out_types = FlattenTupleType(func_type->ret_type);
+  Array<Integer> is_input;
+  for (size_t i = 0; i < func_type->arg_types.size(); ++i) {
+    auto const& aty = func_type->arg_types[i];
+    size_t num_types = 1;
+    if (aty.as<TupleTypeNode>()) {
+      num_types = FlattenTupleType(aty).size();
+    }
+    for (size_t j = 0; j < num_types; ++j) {
+      is_input.push_back(shape_func_attrs->is_input[i]);
+    }
+  }
+
+  Array<Type> shape_func_ins, shape_func_outs;
+  for (size_t i = 0; i < in_types.size(); i++) {
+    auto in_type = in_types[i];
+
+    if (is_input[i]) {
+      shape_func_ins.push_back(in_type);
+    } else {
+      auto shape = RankShape(in_type->shape);
+      shape_func_ins.push_back(TensorType(shape, DataType::Int(64)));
+    }
+  }
+
+  for (auto out_type : out_types) {
+    auto rank_shape = RankShape(out_type->shape);
+    shape_func_outs.push_back(TensorType(rank_shape, DataType::Int(64)));
+  }
+
+  auto input_type = TupleType(shape_func_ins);
+  auto output_type = TupleType(shape_func_outs);
+
+  reporter->Assign(types[1], input_type);
+  reporter->Assign(types[2], output_type);
+  reporter->Assign(types[3], TupleType::Empty());
+
+  return true;
+}
+
+RELAY_REGISTER_OP("vm.shape_func")
+    .describe(R"code(Get the shape of a tensor.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(3)
+    .add_argument("tensor", "Tensor", "The tensor to retrieve the shape for.")
+    .add_type_rel("ShapeFuncRel", ShapeFuncRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+    .set_attr<FTVMCompute>("FTVMCompute",
+                           [](const Attrs& attrs, const Array<te::Tensor>& inputs,
+                              const Type& out_dtype) -> Array<te::Tensor> {
+                             return {topi::identity(inputs[0])};
+                           });
+
+bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,

Review comment:
       ```suggestion
   bool InvokeTVMOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   ```

##########
File path: python/tvm/relay/op/vm/vm.py
##########
@@ -33,3 +33,47 @@ def shape_of(expr):
         The expression with the evaluated tensor shape.
     """
     return _ffi_api.shape_of(expr)
+
+
+def invoke_tvm_op(func, inputs, outputs):
+    """Call a primitive function with the TVM operator calling convention.
+
+    Parameters
+    ----------
+    func : tvm.relay.Expr
+        The input expr.
+
+    inputs : tvm.relay.Expr
+        A tuple of the inputs to pass to the TVM function.
+
+    outputs : tvm.relay.Expr
+        A tuple of the outputs to pass to the TVM function.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The invoke_tvm_op call node.
+    """
+    return _ffi_api.invoke_tvm_op(func, inputs, outputs)
+
+
+def shape_func(func, inputs, outputs, dependent=False):
+    """Invoke the shape function of the passed function.
+
+    Parameters
+    ----------
+    func : tvm.relay.Expr
+        The primitive function from which to compute the shape function.
+
+    inputs : tvm.relay.Tuple
+        The tupled inputs.
+
+    outputs : tvm.relay.Tuple
+        The tupled outputs.
+

Review comment:
       Please document the `dependent` parameter in the docstring as well.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org