You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/06/15 07:10:12 UTC

[GitHub] [tvm] wzh99 opened a new pull request, #11728: [Relay] Implement `SoftmaxRel` for softmax operators.

wzh99 opened a new pull request, #11728:
URL: https://github.com/apache/tvm/pull/11728

   This PR fixes #11684. I replace `IdentityRel` in `nn.softmax`, `nn.fast_softmax` and `nn.log_softmax` with a newly implemented `SoftmaxRel` so that the attribute `axis` is checked during type inference. For the test case shown in #11684, the following error is reported:
   
   ```
   The axis is not in range [-1, 1)
   Traceback (most recent call last):
     File "/Users/wzh/tvm-bug/bug_softmax_axis.py", line 8, in <module>
       mod = relay.transform.InferType()(mod)
     File "/Users/wzh/tvm-dev/python/tvm/ir/transform.py", line 161, in __call__
       return _ffi_transform_api.RunPass(self, mod)
     File "/Users/wzh/tvm-dev/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
       raise get_last_ffi_error()
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     [bt] (8) 9   libtvm.dylib                        0x0000000119ef03b4 tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::$_6>(tvm::transform::$_6, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) + 948
     [bt] (7) 8   libtvm.dylib                        0x0000000119ee5964 tvm::transform::Pass::operator()(tvm::IRModule) const + 148
     [bt] (6) 7   libtvm.dylib                        0x0000000119ee5d71 tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const + 753
     [bt] (5) 6   libtvm.dylib                        0x0000000119ee6873 tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const + 819
     [bt] (4) 5   libtvm.dylib                        0x000000011b21ddfd tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<void tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_2>(tvm::relay::transform::InferType()::$_2)::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) + 1933
     [bt] (3) 4   libtvm.dylib                        0x000000011b20d217 tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function) + 135
     [bt] (2) 3   libtvm.dylib                        0x000000011afd2a2f tvm::relay::TypeSolver::Solve() + 1615
     [bt] (1) 2   libtvm.dylib                        0x0000000119b86699 tvm::runtime::detail::LogFatal::Entry::Finalize() + 89
     [bt] (0) 1   libtvm.dylib                        0x000000011b5a3508 tvm::runtime::Backtrace() + 24
     [bt] (8) 9   libtvm.dylib                        0x000000011b20d217 tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function) + 135
     [bt] (7) 8   libtvm.dylib                        0x000000011afd285c tvm::relay::TypeSolver::Solve() + 1148
     [bt] (6) 7   libtvm.dylib                        0x000000011afd2dd0 tvm::TypedEnvFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::operator()(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&) const + 416
     [bt] (5) 6   libtvm.dylib                        0x000000011a08b154 tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<void tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) + 20
     [bt] (4) 5   libtvm.dylib                        0x000000011a08b563 void tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::'lambda'(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const + 1027
     [bt] (3) 4   libtvm.dylib                        0x000000011acd163e tvm::relay::SoftmaxRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&) + 942
     [bt] (2) 3   libtvm.dylib                        0x0000000119e6a08b tvm::DiagnosticContext::Render() + 459
     [bt] (1) 2   libtvm.dylib                        0x0000000119b86699 tvm::runtime::detail::LogFatal::Entry::Finalize() + 89
     [bt] (0) 1   libtvm.dylib                        0x000000011b5a3508 tvm::runtime::Backtrace() + 24
     File "/Users/wzh/tvm-dev/src/relay/analysis/type_solver.cc", line 624
   TVMError: 
   ---------------------------------------------------------------
   An error occurred during the execution of TVM.
   For more information, please see: https://tvm.apache.org/docs/errors.html
   ---------------------------------------------------------------
     Check failed: (false) is false: [15:01:35] /Users/wzh/tvm-dev/src/ir/diagnostic.cc:105: DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] ganler commented on a diff in pull request #11728: [Relay] Implement `SoftmaxRel` for softmax operators.

Posted by GitBox <gi...@apache.org>.
ganler commented on code in PR #11728:
URL: https://github.com/apache/tvm/pull/11728#discussion_r897705134


##########
src/relay/op/nn/nn.cc:
##########
@@ -399,6 +399,27 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
 // relay.softmax
 TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
 
+bool SoftmaxRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const SoftmaxAttrs* param = attrs.as<SoftmaxAttrs>();
+  ICHECK(param != nullptr);
+  int axis = param->axis;
+  int ndim = static_cast<int>(data->shape.size());
+  if (axis >= ndim || axis < -ndim) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "The axis is not in range [" << -ndim << ", " << ndim

Review Comment:
   A better err msg should indicate the wrong axis. For example, 
   
   ```c++
   ...
   << "Wrong axis (" << axis << ") not in expected range: [" << -ndim << ", " << ndim << ")";
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] masahi merged pull request #11728: [Relay] Implement `SoftmaxRel` for softmax operators.

Posted by GitBox <gi...@apache.org>.
masahi merged PR #11728:
URL: https://github.com/apache/tvm/pull/11728


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] ganler commented on a diff in pull request #11728: [Relay] Implement `SoftmaxRel` for softmax operators.

Posted by GitBox <gi...@apache.org>.
ganler commented on code in PR #11728:
URL: https://github.com/apache/tvm/pull/11728#discussion_r898283093


##########
src/relay/op/nn/nn.cc:
##########
@@ -399,6 +399,27 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
 // relay.softmax
 TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
 
+bool SoftmaxRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const SoftmaxAttrs* param = attrs.as<SoftmaxAttrs>();
+  ICHECK(param != nullptr);
+  int axis = param->axis;
+  int ndim = static_cast<int>(data->shape.size());
+  if (axis >= ndim || axis < -ndim) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "The axis is not in range [" << -ndim << ", " << ndim

Review Comment:
   It seems the value of variable 'axis' still has not been printed?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] wzh99 commented on a diff in pull request #11728: [Relay] Implement `SoftmaxRel` for softmax operators.

Posted by GitBox <gi...@apache.org>.
wzh99 commented on code in PR #11728:
URL: https://github.com/apache/tvm/pull/11728#discussion_r897888041


##########
src/relay/op/nn/nn.cc:
##########
@@ -399,6 +399,27 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
 // relay.softmax
 TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
 
+bool SoftmaxRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const SoftmaxAttrs* param = attrs.as<SoftmaxAttrs>();
+  ICHECK(param != nullptr);
+  int axis = param->axis;
+  int ndim = static_cast<int>(data->shape.size());
+  if (axis >= ndim || axis < -ndim) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "The axis is not in range [" << -ndim << ", " << ndim

Review Comment:
   I have modified the error message as you suggested.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] ganler commented on pull request #11728: [Relay] Implement `SoftmaxRel` for softmax operators.

Posted by GitBox <gi...@apache.org>.
ganler commented on PR #11728:
URL: https://github.com/apache/tvm/pull/11728#issuecomment-1157014075

   Basically this PR lets Relay reject invalid softmax operators (axis >= rank) as early as the type inference phase (though such invalid cases will be rejected anyhow in later checks). @masahi Can you help decide if we want to merge this improvement? Thanks!


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org