You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/11/03 20:42:13 UTC

[tvm] branch main updated: [Relay] Add tensor rank check for `nn.instance_norm` (#13280)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 90ed632280 [Relay] Add tensor rank check for `nn.instance_norm` (#13280)
90ed632280 is described below

commit 90ed632280898dafeac40913d006993bc71a8409
Author: WANG Zihan <wz...@126.com>
AuthorDate: Fri Nov 4 04:42:07 2022 +0800

    [Relay] Add tensor rank check for `nn.instance_norm` (#13280)
    
    Add tensor rank check for `nn.instance_norm`.
---
 src/relay/op/nn/nn.cc | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 8644957b1c..9e2fe63b00 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -923,6 +923,7 @@ bool InstanceNormRel(const Array<Type>& types, int num_inputs, const Attrs& attr
   ICHECK_EQ(types.size(), 4);
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) return false;
+  ICHECK_GT(data->shape.size(), 2);
   const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
   int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
   ICHECK(axis >= 0 && axis < (int)data->shape.size());