You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/01/28 22:06:24 UTC

[tvm] branch main updated: [Relay] Type Relation Fixes (#7362)

This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b8ad146  [Relay] Type Relation Fixes (#7362)
b8ad146 is described below

commit b8ad146dfd00710376e9477dd2367cc94399d9bb
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Thu Jan 28 15:06:12 2021 -0700

    [Relay] Type Relation Fixes (#7362)
    
    * Fix an error in the dynamic Full Type Relation
    
    * Add Diagnostic Errors to Broadcast Type Relations
---
 src/relay/op/dyn/tensor/transform.cc |  3 +++
 src/relay/op/type_relations.cc       | 12 ++++++++++--
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc
index e4e81e3..8bad394 100644
--- a/src/relay/op/dyn/tensor/transform.cc
+++ b/src/relay/op/dyn/tensor/transform.cc
@@ -400,6 +400,9 @@ bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (fill_value == nullptr) {
     return false;
   }
+  if (fill_shape == nullptr) {
+    return false;
+  }
 
   DataType out_dtype = param->dtype;
   if (out_dtype.bits() == 0) {
diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
index 7a3bfcb..7b30aea 100644
--- a/src/relay/op/type_relations.cc
+++ b/src/relay/op/type_relations.cc
@@ -104,7 +104,11 @@ bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   //                 << ",Out:" << types[2] << std::endl;
   if (auto* t0 = types[0].as<TensorTypeNode>()) {
     if (auto* t1 = types[1].as<TensorTypeNode>()) {
-      ICHECK_EQ(t0->dtype, t1->dtype);
+      if (t0->dtype != t1->dtype) {
+        reporter->GetDiagCtx().Emit(Diagnostic::Error(t0->span)
+                                    << "data types " << t0->dtype << " and " << t1->dtype
+                                    << "do not match in BroadcastRel");
+      }
       reporter->Assign(
           types[2], ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), t0->dtype));
       return true;
@@ -120,7 +124,11 @@ bool BroadcastCompRel(const Array<Type>& types, int num_inputs, const Attrs& att
   //                 << ",Out:" << types[2] << std::endl;
   if (auto* t0 = types[0].as<TensorTypeNode>()) {
     if (auto* t1 = types[1].as<TensorTypeNode>()) {
-      ICHECK_EQ(t0->dtype, t1->dtype);
+      if (t0->dtype != t1->dtype) {
+        reporter->GetDiagCtx().Emit(Diagnostic::Error(t0->span)
+                                    << "data types " << t0->dtype << " and " << t1->dtype
+                                    << "do not match in BroadcastCompRel");
+      }
       reporter->Assign(types[2], ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1),
                                                    DataType::Bool()));
       return true;