You are viewing a plain-text version of this content. The canonical version is available at the original archive link ("here"), which is not preserved in this plain-text rendering.
Posted to commits@tvm.apache.org by co...@apache.org on 2022/05/05 15:34:38 UTC

[tvm] branch main updated: Fix mixed precision output type to original type (#11142)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new eae836cdf6 Fix mixed precision output type to original type (#11142)
eae836cdf6 is described below

commit eae836cdf66f54f1e81e78e48bfa051431e8556f
Author: Gayatri P K <qu...@quicinc.com>
AuthorDate: Thu May 5 21:04:30 2022 +0530

    Fix mixed precision output type to original type (#11142)
---
 src/relay/transforms/to_mixed_precision.cc    | 60 +++++++++++++++++++++++----
 tests/python/relay/test_to_mixed_precision.py | 39 ++++++++++++-----
 2 files changed, 82 insertions(+), 17 deletions(-)

diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc
index 4ad3482f74..e1d3a264c2 100644
--- a/src/relay/transforms/to_mixed_precision.cc
+++ b/src/relay/transforms/to_mixed_precision.cc
@@ -36,6 +36,7 @@
 namespace tvm {
 namespace relay {
 
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.ToMixedPrecision.keep_orig_output_dtype", Bool);
 // A callable which hashes std::pair
 struct pair_hash {
   template <class T1, class T2>
@@ -105,6 +106,9 @@ class MixedPrecisionPass : public MixedModeMutator {
    * encountered. Used for emitting warnings on missing ops in the pass.
    */
   std::unordered_map<std::string, int> missing_ops_;
+  const RelayExprNode* root_;
+  std::vector<DataType> original_dtype_;
+  bool keep_orig_output_dtype_;
 
   Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
     /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
@@ -278,8 +282,23 @@ class MixedPrecisionPass : public MixedModeMutator {
  public:
   using MixedModeMutator::VisitExpr_;
 
-  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16))
-      : MixedModeMutator(), mixed_precision_type_(mixed_precision_type) {
+  explicit MixedPrecisionPass(Expr base, bool keep_orig_output_dtype,
+                              DataType mixed_precision_type = DataType::Float(16))
+      : MixedModeMutator(),
+        mixed_precision_type_(mixed_precision_type),
+        root_(Downcast<Function>(base)->body.get()),
+        keep_orig_output_dtype_(keep_orig_output_dtype) {
+    if (keep_orig_output_dtype_) {
+      if (root_->IsInstance<tvm::relay::TupleNode>()) {
+        const TupleTypeNode* tuple_type = (root_->checked_type_).as<TupleTypeNode>();
+        for (Type t : tuple_type->fields) {
+          const TensorTypeNode* tensor_type = t.as<TensorTypeNode>();
+          original_dtype_.push_back(tensor_type->dtype);
+        }
+      } else if (root_->IsInstance<tvm::relay::CallNode>()) {
+        original_dtype_.push_back((root_->checked_type_).as<TensorTypeNode>()->dtype);
+      }
+    }
     if (!mixed_precision_type_.is_float() && !mixed_precision_type_.is_bfloat16()) {
       LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got "
                  << mixed_precision_type_;
@@ -381,6 +400,11 @@ class MixedPrecisionPass : public MixedModeMutator {
       if (accumulation_dtype != output_dtype) {
         output = CastArg(output, GetType(output), output_dtype);
       }
+      if (pre_call_node == root_ && keep_orig_output_dtype_) {
+        if (original_dtype_[0] != output_dtype) {
+          output = CastArg(output, GetType(output), original_dtype_[0]);
+        }
+      }
       return output;
     }
 
@@ -396,6 +420,21 @@ class MixedPrecisionPass : public MixedModeMutator {
   Expr Rewrite_(const TupleNode* pre, const Expr& post) {
     // The old checked type in the expression may not be valid so clear it
     post->checked_type_ = Type(nullptr);
+    if (pre == root_ && keep_orig_output_dtype_) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < original_dtype_.size(); i++) {
+        Expr output_element = GetField(post, i);
+        Expr casted_element;
+        auto output_element_type = transform::InferTypeLocal(output_element);
+        casted_element = CastArg(output_element, output_element_type, original_dtype_[i]);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(output_element);
+      }
+      if (!all_same) {
+        return Tuple(new_expr);
+      }
+    }
     return post;
   }
 
@@ -421,11 +460,12 @@ class MixedPrecisionPass : public MixedModeMutator {
   }
 
   // To access map of ops not registered for error reporting
-  friend Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type,
-                               int missing_op_mode);
+  friend Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
+                               const DataType& mixed_precision_type, int missing_op_mode);
 };
 
-Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, int missing_op_mode) {
+Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
+                      const DataType& mixed_precision_type, int missing_op_mode) {
   /*
   missing_op_mode:
 
@@ -436,7 +476,8 @@ Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, in
   ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2)
       << " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode;
 
-  MixedPrecisionPass converter = MixedPrecisionPass(mixed_precision_type);
+  MixedPrecisionPass converter =
+      MixedPrecisionPass(expr, keep_orig_output_dtype, mixed_precision_type);
   auto result = converter.Mutate(expr);
 
   for (auto it = converter.missing_ops_.begin();
@@ -460,7 +501,12 @@ namespace transform {
 Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(ToMixedPrecision(f, mixed_precision_type, missing_op_mode));
+        bool keep_orig_output_dtype = false;
+        keep_orig_output_dtype = pc->GetConfig("relay.ToMixedPrecision.keep_orig_output_dtype",
+                                               Bool(keep_orig_output_dtype))
+                                     .value();
+        return Downcast<Function>(
+            ToMixedPrecision(f, keep_orig_output_dtype, mixed_precision_type, missing_op_mode));
       };
   return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {});
 }
diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py
index 2afd6ff247..026b458bde 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -41,17 +41,31 @@ def verify_mixed_precision_output_close(
     mixed_precision_dtype="float16",
     rtol: float = 1e-3,
     atol: float = 0,
+    keep_orig_output_dtype=False,
 ) -> tvm.runtime.Module:
 
     mod = InferType()(mod)
     result_fp32 = run_module(mod, mod_params)
-    fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
-    result_fp16 = run_module(fp16_mod, mod_params)
+
+    if not keep_orig_output_dtype:
+        fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
+        result_fp16 = run_module(fp16_mod, mod_params)
+    else:
+        with tvm.transform.PassContext(
+            config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
+        ):
+            fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
+            result_fp16 = run_module(fp16_mod, mod_params)
 
     # Ensure the results are close
     for fp32, fp16 in zip(result_fp32, result_fp16):
         np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
 
+    if keep_orig_output_dtype:
+        assert (
+            np.array(result_fp16).dtype == np.array(result_fp32).dtype
+        ), "output type and original type mismatch"
+
     return fp16_mod
 
 
@@ -117,16 +131,21 @@ def test_convert_single_conv():
         "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
         "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
     }
-    fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+    fp16_mod = verify_mixed_precision_output_close(
+        mod, mod_params, atol=0.01, rtol=1e-3, keep_orig_output_dtype=True
+    )
 
     expected_mod = tvm.IRModule.from_expr(
-        relay.nn.conv2d(
-            relay.cast(data, "float16"),
-            relay.cast(weight, "float16"),
-            strides=(1, 1),
-            padding=(1, 1),
-            out_dtype="float16",
-        ),
+        relay.cast(
+            relay.nn.conv2d(
+                relay.cast(data, "float16"),
+                relay.cast(weight, "float16"),
+                strides=(1, 1),
+                padding=(1, 1),
+                out_dtype="float16",
+            ),
+            "float32",
+        )
     )
     expected_mod = tvm.relay.transform.InferType()(expected_mod)