You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was lost in this plain-text rendering).
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/02/24 10:42:03 UTC

[GitHub] [tvm] manupa-arm commented on a change in pull request #10348: [CMSIS-NN] Fix scalar to tensor constant pass when checked type is TupleType

manupa-arm commented on a change in pull request #10348:
URL: https://github.com/apache/tvm/pull/10348#discussion_r813752267



##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -117,21 +117,19 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
     for (uint32_t i = 0; i < call->args.size(); ++i) {
       Expr arg = call->args[i];
       new_args.push_back(arg);
-      if (!arg->checked_type_.defined()) {
-        continue;
-      }
-      auto* arg_type = arg->type_as<TensorTypeNode>();
-      if (arg_type->shape.size() != 0 || arg.as<ConstantNode>()) {
+      const auto* arg_var = arg.as<VarNode>();
+      const auto* arg_type = arg->checked_type_.as<TensorTypeNode>();
+      if (!arg_var || !arg_type || arg_type->shape.size() != 0) {

Review comment:
       nit: could use `->IsInstance<VarNode>()` if arg_var and arg_type are not used anymore.

##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -142,19 +140,18 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
     for (uint32_t i = 0; i < call->args.size(); ++i) {
       new_args.push_back(call->args[i]);
       Expr scalar_arg = call->args[i];
-      if (!scalar_arg->checked_type_.defined()) {
-        continue;
-      }
-      Array<PrimExpr> scalar_shape = scalar_arg->type_as<TensorTypeNode>()->shape;
-      if (scalar_shape.size() != 0 || scalar_arg.as<ConstantNode>() == nullptr) {
+      const auto* scalar_const = scalar_arg.as<ConstantNode>();
+      const auto* scalar_type = scalar_arg->checked_type_.as<TensorTypeNode>();

Review comment:
       same nit as above

##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -142,19 +140,18 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
     for (uint32_t i = 0; i < call->args.size(); ++i) {
       new_args.push_back(call->args[i]);
       Expr scalar_arg = call->args[i];
-      if (!scalar_arg->checked_type_.defined()) {
-        continue;
-      }
-      Array<PrimExpr> scalar_shape = scalar_arg->type_as<TensorTypeNode>()->shape;
-      if (scalar_shape.size() != 0 || scalar_arg.as<ConstantNode>() == nullptr) {
+      const auto* scalar_const = scalar_arg.as<ConstantNode>();
+      const auto* scalar_type = scalar_arg->checked_type_.as<TensorTypeNode>();
+      if (!scalar_const || !scalar_type || scalar_type->shape.size() != 0) {
         continue;
       }
       int tensor_arg_id = (i + 1) % 2;
       Expr tensor_arg = call->args[tensor_arg_id];
-      if (!tensor_arg->checked_type_.defined()) {
+      const auto* tensor_type_node = tensor_arg->checked_type_.as<TensorTypeNode>();
+      if (!tensor_type_node) {
         continue;
       }
-      TensorType tensor_type = GetRef<TensorType>(tensor_arg->type_as<TensorTypeNode>());
+      TensorType tensor_type = GetRef<TensorType>(tensor_type_node);

Review comment:
       same nit as above

##########
File path: src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc
##########
@@ -117,21 +117,19 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
     for (uint32_t i = 0; i < call->args.size(); ++i) {
       Expr arg = call->args[i];
       new_args.push_back(arg);
-      if (!arg->checked_type_.defined()) {
-        continue;
-      }
-      auto* arg_type = arg->type_as<TensorTypeNode>();
-      if (arg_type->shape.size() != 0 || arg.as<ConstantNode>()) {
+      const auto* arg_var = arg.as<VarNode>();
+      const auto* arg_type = arg->checked_type_.as<TensorTypeNode>();
+      if (!arg_var || !arg_type || arg_type->shape.size() != 0) {
         continue;
       }
       String arg_name = arg.as<VarNode>()->name_hint();
       int tensor_arg_id = (i + 1) % 2;
       Expr tensor_arg = call->args[tensor_arg_id];
-      if (!tensor_arg->checked_type_.defined()) {
+      const auto* tensor_type = tensor_arg->checked_type_.as<TensorTypeNode>();
+      if (!tensor_type) {
         continue;
       }
-      TensorType tensor_type = GetRef<TensorType>(tensor_arg->type_as<TensorTypeNode>());
-      new_args.Set(i, Var(arg_name, tensor_type));
+      new_args.Set(i, Var(arg_name, GetRef<TensorType>(tensor_type)));

Review comment:
       nit: same here — we don't need to extract the node just to convert it back to a Ref.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org