You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/02/07 12:36:17 UTC

[GitHub] [tvm] lhutton1 commented on a change in pull request #10060: [microNPU] Refactor type inference data type checks

lhutton1 commented on a change in pull request #10060:
URL: https://github.com/apache/tvm/pull/10060#discussion_r800611881



##########
File path: src/relay/op/contrib/ethosu/common.cc
##########
@@ -75,6 +78,56 @@ Array<IndexExpr> EthosuInferKernelOutput(Array<IndexExpr> ifm_shape, String ifm_
   return output_shape;
 }
 
+DataType DataTypeFromString(const String& dtype) {
+  DLDataType dl_dtype = tvm::runtime::String2DLDataType(dtype);
+  return DataType(dl_dtype);
+}
+
+void CheckDataType(const TypeReporter& reporter, const DataType& data_type,
+                   const std::unordered_set<DataType>& allowed_data_types,
+                   const String& operator_name, const String& tensor_name,
+                   const String& operator_type) {
+  if (allowed_data_types.find(data_type) != allowed_data_types.end()) {
+    return;
+  }
+
+  std::ostringstream message;
+  message << "Invalid operator: expected " << operator_name << " ";
+  if (operator_type != "") {
+    message << operator_type << " ";
+  }
+  message << "to have type in {";
+  for (auto it = allowed_data_types.begin(); it != allowed_data_types.end(); ++it) {
+    message << *it;
+    if (std::next(it) != allowed_data_types.end()) {
+      message << ", ";
+    }
+  }
+  message << "}";
+  message << " for " << tensor_name << " but was " << data_type << ".";
+
+  reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << message.str());
+}
+
+void CheckDataTypeMatch(const TypeReporter& reporter, const DataType& data_type,
+                        const DataType& data_type2, const String& operator_name,
+                        const String& tensor_name, const String& tensor_name2,
+                        const String& operator_type) {
+  if (data_type == data_type2) {
+    return;
+  }
+
+  std::ostringstream message;
+  message << "Invalid operator: expected " << operator_name << " ";
+  if (operator_type != " ") {

Review comment:
       Good spot, thanks!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org