You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/11/03 20:42:13 UTC
[tvm] branch main updated: [Relay] Add tensor rank check for `nn.instance_norm` (#13280)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 90ed632280 [Relay] Add tensor rank check for `nn.instance_norm` (#13280)
90ed632280 is described below
commit 90ed632280898dafeac40913d006993bc71a8409
Author: WANG Zihan <wz...@126.com>
AuthorDate: Fri Nov 4 04:42:07 2022 +0800
[Relay] Add tensor rank check for `nn.instance_norm` (#13280)
Add tensor rank check for `nn.instance_norm`.
---
src/relay/op/nn/nn.cc | 1 +
1 file changed, 1 insertion(+)
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 8644957b1c..9e2fe63b00 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -923,6 +923,7 @@ bool InstanceNormRel(const Array<Type>& types, int num_inputs, const Attrs& attr
ICHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
+ ICHECK_GT(data->shape.size(), 2);
const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
ICHECK(axis >= 0 && axis < (int)data->shape.size());