You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/01/28 22:06:24 UTC
[tvm] branch main updated: [Relay] Type Relation Fixes (#7362)
This is an automated email from the ASF dual-hosted git repository.
mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b8ad146 [Relay] Type Relation Fixes (#7362)
b8ad146 is described below
commit b8ad146dfd00710376e9477dd2367cc94399d9bb
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Thu Jan 28 15:06:12 2021 -0700
[Relay] Type Relation Fixes (#7362)
* Fix an error in the dynamic Full type relation (return false while the fill-shape type is not yet known)
* Add diagnostic errors to the Broadcast type relations
---
src/relay/op/dyn/tensor/transform.cc | 3 +++
src/relay/op/type_relations.cc | 12 ++++++++++--
2 files changed, 13 insertions(+), 2 deletions(-)
diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc
index e4e81e3..8bad394 100644
--- a/src/relay/op/dyn/tensor/transform.cc
+++ b/src/relay/op/dyn/tensor/transform.cc
@@ -400,6 +400,9 @@ bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (fill_value == nullptr) {
return false;
}
+ if (fill_shape == nullptr) {
+ return false;
+ }
DataType out_dtype = param->dtype;
if (out_dtype.bits() == 0) {
diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
index 7a3bfcb..7b30aea 100644
--- a/src/relay/op/type_relations.cc
+++ b/src/relay/op/type_relations.cc
@@ -104,7 +104,11 @@ bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// << ",Out:" << types[2] << std::endl;
if (auto* t0 = types[0].as<TensorTypeNode>()) {
if (auto* t1 = types[1].as<TensorTypeNode>()) {
- ICHECK_EQ(t0->dtype, t1->dtype);
+ if (t0->dtype != t1->dtype) {
+ reporter->GetDiagCtx().Emit(Diagnostic::Error(t0->span)
+ << "data types " << t0->dtype << " and " << t1->dtype
+ << "do not match in BroadcastRel");
+ }
reporter->Assign(
types[2], ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), t0->dtype));
return true;
@@ -120,7 +124,11 @@ bool BroadcastCompRel(const Array<Type>& types, int num_inputs, const Attrs& att
// << ",Out:" << types[2] << std::endl;
if (auto* t0 = types[0].as<TensorTypeNode>()) {
if (auto* t1 = types[1].as<TensorTypeNode>()) {
- ICHECK_EQ(t0->dtype, t1->dtype);
+ if (t0->dtype != t1->dtype) {
+ reporter->GetDiagCtx().Emit(Diagnostic::Error(t0->span)
+ << "data types " << t0->dtype << " and " << t1->dtype
+ << "do not match in BroadcastCompRel");
+ }
reporter->Assign(types[2], ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1),
DataType::Bool()));
return true;