You are viewing a plain-text version of this content; the link to the canonical version (originally attached to the word "here") was not preserved in this copy.
Posted to commits@tvm.apache.org by sy...@apache.org on 2024/03/15 08:05:59 UTC

(tvm) branch main updated: [TIR] Implement max/min_value for fp8 data types (#16723)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9ec72494cf [TIR] Implement max/min_value for fp8 data types (#16723)
9ec72494cf is described below

commit 9ec72494cf71a6a6c6a94d29e33c986cbfaaf5fc
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri Mar 15 01:05:53 2024 -0700

    [TIR] Implement max/min_value for fp8 data types (#16723)
---
 src/tir/op/op.cc | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index c46a8c2643..7f47e66062 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -262,6 +262,12 @@ PrimExpr max_value(const DataType& dtype, Span span) {
     }
   } else if (dtype.is_bfloat16()) {
     return FloatImm(dtype, std::numeric_limits<float>::max(), span);
+  } else if (dtype.is_float8()) {
+    if (dtype.code() == DataType::TypeCode::kE5M2Float) {
+      return FloatImm(dtype, 57344.0, span);
+    } else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
+      return FloatImm(dtype, 448.0, span);
+    }
   }
   LOG(FATAL) << "Cannot decide max_value for type" << dtype;
 }
@@ -296,6 +302,12 @@ PrimExpr min_value(const DataType& dtype, Span span) {
     }
   } else if (dtype.is_bfloat16()) {
     return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
+  } else if (dtype.is_float8()) {
+    if (dtype.code() == DataType::TypeCode::kE5M2Float) {
+      return FloatImm(dtype, -57344.0, span);
+    } else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
+      return FloatImm(dtype, -448.0, span);
+    }
   }
   LOG(FATAL) << "Cannot decide min_value for type" << dtype;
 }