Posted to commits@tvm.apache.org by co...@apache.org on 2022/05/05 15:34:38 UTC
[tvm] branch main updated: Fix mixed precision output type to original type (#11142)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new eae836cdf6 Fix mixed precision output type to original type (#11142)
eae836cdf6 is described below
commit eae836cdf66f54f1e81e78e48bfa051431e8556f
Author: Gayatri P K <qu...@quicinc.com>
AuthorDate: Thu May 5 21:04:30 2022 +0530
Fix mixed precision output type to original type (#11142)
---
src/relay/transforms/to_mixed_precision.cc | 60 +++++++++++++++++++++++----
tests/python/relay/test_to_mixed_precision.py | 39 ++++++++++++-----
2 files changed, 82 insertions(+), 17 deletions(-)
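Editor's note: the patch below adds a pass-context option, relay.ToMixedPrecision.keep_orig_output_dtype, which makes the ToMixedPrecision pass cast the rewritten function's outputs back to their original dtype. A minimal usage sketch of that option from Python follows; the module, shapes, and variable names here are illustrative and not part of the commit.

    import tvm
    from tvm import relay
    from tvm.relay.transform import InferType, ToMixedPrecision

    # Build a small float32 conv2d module (shapes chosen purely for illustration).
    data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
    weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1))
    mod = InferType()(tvm.IRModule.from_expr(conv))

    # With the new option enabled, compute runs in float16 but the function's
    # output is cast back to the original float32 dtype.
    with tvm.transform.PassContext(
        config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
    ):
        fp16_mod = ToMixedPrecision("float16")(mod)

    print(InferType()(fp16_mod)["main"].ret_type)  # expect a float32 tensor type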
diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc
index 4ad3482f74..e1d3a264c2 100644
--- a/src/relay/transforms/to_mixed_precision.cc
+++ b/src/relay/transforms/to_mixed_precision.cc
@@ -36,6 +36,7 @@
namespace tvm {
namespace relay {
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.ToMixedPrecision.keep_orig_output_dtype", Bool);
// A callable which hashes std::pair
struct pair_hash {
template <class T1, class T2>
@@ -105,6 +106,9 @@ class MixedPrecisionPass : public MixedModeMutator {
* encountered. Used for emitting warnings on missing ops in the pass.
*/
std::unordered_map<std::string, int> missing_ops_;
+ const RelayExprNode* root_;
+ std::vector<DataType> original_dtype_;
+ bool keep_orig_output_dtype_;
Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
/* If the accumulation dtype is in the attributes make a copy and mutate the field. */
@@ -278,8 +282,23 @@ class MixedPrecisionPass : public MixedModeMutator {
public:
using MixedModeMutator::VisitExpr_;
- explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16))
- : MixedModeMutator(), mixed_precision_type_(mixed_precision_type) {
+ explicit MixedPrecisionPass(Expr base, bool keep_orig_output_dtype,
+ DataType mixed_precision_type = DataType::Float(16))
+ : MixedModeMutator(),
+ mixed_precision_type_(mixed_precision_type),
+ root_(Downcast<Function>(base)->body.get()),
+ keep_orig_output_dtype_(keep_orig_output_dtype) {
+ if (keep_orig_output_dtype_) {
+ if (root_->IsInstance<tvm::relay::TupleNode>()) {
+ const TupleTypeNode* tuple_type = (root_->checked_type_).as<TupleTypeNode>();
+ for (Type t : tuple_type->fields) {
+ const TensorTypeNode* tensor_type = t.as<TensorTypeNode>();
+ original_dtype_.push_back(tensor_type->dtype);
+ }
+ } else if (root_->IsInstance<tvm::relay::CallNode>()) {
+ original_dtype_.push_back((root_->checked_type_).as<TensorTypeNode>()->dtype);
+ }
+ }
if (!mixed_precision_type_.is_float() && !mixed_precision_type_.is_bfloat16()) {
LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got "
<< mixed_precision_type_;
@@ -381,6 +400,11 @@ class MixedPrecisionPass : public MixedModeMutator {
if (accumulation_dtype != output_dtype) {
output = CastArg(output, GetType(output), output_dtype);
}
+ if (pre_call_node == root_ && keep_orig_output_dtype_) {
+ if (original_dtype_[0] != output_dtype) {
+ output = CastArg(output, GetType(output), original_dtype_[0]);
+ }
+ }
return output;
}
@@ -396,6 +420,21 @@ class MixedPrecisionPass : public MixedModeMutator {
Expr Rewrite_(const TupleNode* pre, const Expr& post) {
// The old checked type in the expression may not be valid so clear it
post->checked_type_ = Type(nullptr);
+ if (pre == root_ && keep_orig_output_dtype_) {
+ Array<Expr> new_expr;
+ bool all_same = true;
+ for (size_t i = 0; i < original_dtype_.size(); i++) {
+ Expr output_element = GetField(post, i);
+ Expr casted_element;
+ auto output_element_type = transform::InferTypeLocal(output_element);
+ casted_element = CastArg(output_element, output_element_type, original_dtype_[i]);
+ new_expr.push_back(casted_element);
+ all_same &= casted_element.same_as(output_element);
+ }
+ if (!all_same) {
+ return Tuple(new_expr);
+ }
+ }
return post;
}
@@ -421,11 +460,12 @@ class MixedPrecisionPass : public MixedModeMutator {
}
// To access map of ops not registered for error reporting
- friend Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type,
- int missing_op_mode);
+ friend Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
+ const DataType& mixed_precision_type, int missing_op_mode);
};
-Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, int missing_op_mode) {
+Expr ToMixedPrecision(const Expr& expr, bool keep_orig_output_dtype,
+ const DataType& mixed_precision_type, int missing_op_mode) {
/*
missing_op_mode:
@@ -436,7 +476,8 @@ Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, in
ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2)
<< " missing_op_mode must be either 0, 1, or 2 got " << missing_op_mode;
- MixedPrecisionPass converter = MixedPrecisionPass(mixed_precision_type);
+ MixedPrecisionPass converter =
+ MixedPrecisionPass(expr, keep_orig_output_dtype, mixed_precision_type);
auto result = converter.Mutate(expr);
for (auto it = converter.missing_ops_.begin();
@@ -460,7 +501,12 @@ namespace transform {
Pass ToMixedPrecision(DataType mixed_precision_type, int missing_op_mode) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(ToMixedPrecision(f, mixed_precision_type, missing_op_mode));
+ bool keep_orig_output_dtype = false;
+ keep_orig_output_dtype = pc->GetConfig("relay.ToMixedPrecision.keep_orig_output_dtype",
+ Bool(keep_orig_output_dtype))
+ .value();
+ return Downcast<Function>(
+ ToMixedPrecision(f, keep_orig_output_dtype, mixed_precision_type, missing_op_mode));
};
return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {});
}
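Editor's note: the Rewrite_ hook for TupleNode in the hunk above casts each field of a multi-output function back to the dtype recorded in original_dtype_. A hedged sketch of that multi-output case, again with illustrative names and shapes that are not part of the commit:

    import tvm
    from tvm import relay
    from tvm.relay.transform import InferType, ToMixedPrecision

    # A function whose body is a tuple of two float32 results.
    x = relay.var("x", shape=(4, 16), dtype="float32")
    w = relay.var("w", shape=(8, 16), dtype="float32")
    dense = relay.nn.dense(x, w)
    out = relay.Tuple([dense, relay.nn.relu(dense)])
    mod = InferType()(tvm.IRModule.from_expr(out))

    with tvm.transform.PassContext(
        config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
    ):
        fp16_mod = ToMixedPrecision("float16")(mod)

    # Each tuple field should be cast back to float32 by the rewritten TupleNode.
    print(InferType()(fp16_mod)["main"].ret_type)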
diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py
index 2afd6ff247..026b458bde 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -41,17 +41,31 @@ def verify_mixed_precision_output_close(
mixed_precision_dtype="float16",
rtol: float = 1e-3,
atol: float = 0,
+ keep_orig_output_dtype=False,
) -> tvm.runtime.Module:
mod = InferType()(mod)
result_fp32 = run_module(mod, mod_params)
- fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
- result_fp16 = run_module(fp16_mod, mod_params)
+
+ if not keep_orig_output_dtype:
+ fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
+ result_fp16 = run_module(fp16_mod, mod_params)
+ else:
+ with tvm.transform.PassContext(
+ config={"relay.ToMixedPrecision.keep_orig_output_dtype": True}
+ ):
+ fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
+ result_fp16 = run_module(fp16_mod, mod_params)
# Ensure the results are close
for fp32, fp16 in zip(result_fp32, result_fp16):
np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
+ if keep_orig_output_dtype:
+ assert (
+ np.array(result_fp16).dtype == np.array(result_fp32).dtype
+ ), "output type and original type mismatch"
+
return fp16_mod
@@ -117,16 +131,21 @@ def test_convert_single_conv():
"data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
"weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
}
- fp16_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+ fp16_mod = verify_mixed_precision_output_close(
+ mod, mod_params, atol=0.01, rtol=1e-3, keep_orig_output_dtype=True
+ )
expected_mod = tvm.IRModule.from_expr(
- relay.nn.conv2d(
- relay.cast(data, "float16"),
- relay.cast(weight, "float16"),
- strides=(1, 1),
- padding=(1, 1),
- out_dtype="float16",
- ),
+ relay.cast(
+ relay.nn.conv2d(
+ relay.cast(data, "float16"),
+ relay.cast(weight, "float16"),
+ strides=(1, 1),
+ padding=(1, 1),
+ out_dtype="float16",
+ ),
+ "float32",
+ )
)
expected_mod = tvm.relay.transform.InferType()(expected_mod)