Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/05/18 17:41:46 UTC

[GitHub] [tvm] AndrewZhaoLuo opened a new pull request #8069: Add FP16 model conversion pass

AndrewZhaoLuo opened a new pull request #8069:
URL: https://github.com/apache/tvm/pull/8069


   Thanks for contributing to TVM! Please refer to the guidelines at https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-mentioning them in the pull request thread.
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-859341545









[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r647776502



##########
File path: tests/python/frontend/mxnet/test_forward.py
##########
@@ -1223,6 +1221,8 @@ def verify(shape, axis=1, fix_gamma=False):
 
 @tvm.testing.uses_gpu
 def test_forward_instance_norm():
+    np.random.seed(90)
+

Review comment:
       Oh, OK, that's an interesting idea. I had a failure where the rtol needed to pass was 1.05e-5, so I'm just going to increase the tolerance.
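For reference, NumPy's `assert_allclose` treats an element as passing when `|actual - desired| <= atol + rtol * |desired|` (defaults `rtol=1e-7`, `atol=0`). A minimal pure-Python sketch of that criterion, using made-up values rather than the actual test data, shows how a relative error of 1.05e-5 fails a check at `rtol=1e-5` and passes once the tolerance is loosened:

```python
def allclose(actual, desired, rtol=1e-7, atol=0.0):
    """Mirror of numpy.testing.assert_allclose's per-element criterion:
    |actual - desired| <= atol + rtol * |desired| must hold for every pair."""
    return all(abs(a - d) <= atol + rtol * abs(d) for a, d in zip(actual, desired))


# Hypothetical values: the first element is off by a relative 1.05e-5.
desired = [1.0, 2.0, 3.0]
actual = [1.0 + 1.05e-5, 2.0, 3.0]

print(allclose(actual, desired, rtol=1e-5))  # False: 1.05e-5 > 1e-5
print(allclose(actual, desired, rtol=2e-5))  # True after loosening rtol
```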







[GitHub] [tvm] CoinCheung commented on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
CoinCheung commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-878677560


   Hi,
   
   I tried this:
   ```python
   def compile_model(mod, params, target, logfile, save_path):
       tvm.relay.backend.compile_engine.get().clear()
       mod = tvm.relay.transform.ToMixedPrecision(
               mixed_precision_type='float16')(mod)
       with tvm.autotvm.apply_history_best(logfile):
           with tvm.transform.PassContext(opt_level=3):
               lib = tvm.relay.build(mod, target=target, params=params)
       lib.export_library(save_path)  # Save the compiled model; the file name must end in .so, otherwise C++ won't recognize it
   
   ```
   But I got the error: 
   ```
   Traceback (most recent call last):
     File "main.py", line 207, in <module>
       args.save_path)
     File "main.py", line 122, in compile_model
       mixed_precision_type='float16')(mod)
     File "/root/build/tvm/python/tvm/ir/transform.py", line 161, in __call__
       return _ffi_transform_api.RunPass(self, mod)
     File "/root/build/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
       raise get_last_ffi_error()
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     23: TVMFuncCall
     22: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
     21: tvm::transform::Pass::operator()(tvm::IRModule) const
     20: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     19: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
     18: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::relay::Function (tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::ToMixedPrecision(tvm::runtime::DataType, int)::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relay::transform::ToMixedPrecision(tvm::runtime::DataType, int)::{lambda(tvm::relay::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
     17: tvm::relay::ToMixedPrecision(tvm::RelayExpr const&, tvm::runtime::DataType const&, int)
     16: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
     15: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
     14: _ZN3tvm5relay16MixedModeMutator17DispatchVisitExprERKNS_9Re
     13: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     12: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
     11: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlR
     10: tvm::relay::MixedPrecisionPass::VisitExpr_(tvm::relay::FunctionNode const*)
     9: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
     8: tvm::relay::MixedModeMutator::VisitExpr(tvm::RelayExpr const&)
     7: tvm::relay::MixedModeMutator::VisitLeaf(tvm::RelayExpr const&)
     6: _ZN3tvm5relay16MixedModeMutator17DispatchVisitExprERKNS_9Re
     5: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
     4: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
     3: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlR
     2: tvm::relay::MixedModeMutator::VisitExpr_(tvm::relay::CallNode const*)
     1: tvm::relay::MixedPrecisionPass::Rewrite_(tvm::relay::CallNode const*, tvm::RelayExpr const&)
     0: tvm::Op::GetAttrMapContainer(tvm::runtime::String const&)
     File "/root/build/tvm/src/ir/../node/attr_registry.h", line 146
   TVMError: Attribute 'FTVMMixedPrecisionConversionType' is not registered
   ```
   
   Did I miss a key step in using this feature?
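The traceback ends in `Op::GetAttrMapContainer`, i.e. the lookup of the whole attribute map, rather than in the per-op `attr_map.count(op)` check that `Rewrite_` uses for ops without a rule. A toy Python model of that distinction (the `AttrRegistry` class is hypothetical and is not TVM's registry API): fetching the map fails only when no op at all has registered the attribute in the running process, while an individual unregistered op is simply absent from the map.

```python
class AttrRegistry:
    """Toy per-attribute op registry (hypothetical; not TVM's actual API)."""

    def __init__(self):
        # attr_name -> {op_name: value}
        self._attrs = {}

    def register(self, op_name, attr_name, value):
        """Attach an attribute value to a single op."""
        self._attrs.setdefault(attr_name, {})[op_name] = value

    def get_attr_map(self, attr_name):
        """Return the op -> value map; fail only if no op registered attr_name."""
        if attr_name not in self._attrs:
            raise LookupError(f"Attribute '{attr_name}' is not registered")
        return self._attrs[attr_name]


reg = AttrRegistry()
try:
    # Nothing registered yet: the whole-map lookup fails, like the traceback above.
    reg.get_attr_map("FTVMMixedPrecisionConversionType")
except LookupError as err:
    print(err)

# Once at least one op registers the attribute, the map exists and unregistered
# ops are simply absent from it (the per-op attr_map.count(op) case).
reg.register("nn.conv2d", "FTVMMixedPrecisionConversionType", "always")
attr_map = reg.get_attr_map("FTVMMixedPrecisionConversionType")
print("nn.conv2d" in attr_map, "nn.relu" in attr_map)  # True False
```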





[GitHub] [tvm] masahi merged pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
masahi merged pull request #8069:
URL: https://github.com/apache/tvm/pull/8069


   





[GitHub] [tvm] AndrewZhaoLuo edited a comment on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo edited a comment on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-858045479









[GitHub] [tvm] comaniac commented on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-861847268


   > LGTM
   
   Sorry, I was reviewing another PR and approved this one by mistake. Please ignore this approval; I'll take another look later.





[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652195571



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));
+
+        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!";
+        if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!";

Review comment:
       Yeah, but maybe don't call it verbose, because `!ignore_missing_ops` will throw errors and terminate the pass. Maybe the following?
   
   ```
   allow_missing_ops: int
   0: Does not allow any missing ops. Will throw errors and terminate the pass when encountering any. 
   1: Allow missing ops but throw warnings.
   2: Allow missing ops and silently ignore them.
   ```
   
   Either 1 or 2 can be the default mode.
   Meanwhile, whatever behavior we decide on in the end, it would be better to also control the messages. For example, the current implementation has two problems:
   1. The pass stops when it sees the first missing op.
   2. If a model has **one** missing op but it appears 100 times, then users will see 100 warning messages.
   
   As a result, it would be better to first collect the unique missing ops along with their counts (e.g., `(op, times)`) and print them at the end of this pass according to the user-provided flag/mode.
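A minimal Python sketch of the proposed reporting, assuming the three `allow_missing_ops` modes above (0 = error, 1 = warn, 2 = silent). The function `report_missing_ops` and the op name `my.custom_op` are hypothetical, not part of the PR: it counts unregistered ops over the whole graph and reports once, so one op seen 100 times yields a single message.

```python
import warnings
from collections import Counter


class MissingOpError(RuntimeError):
    """Raised in the hypothetical mode 0 when any op lacks a conversion rule."""


def report_missing_ops(op_names, registered, allow_missing_ops=1):
    """Collect unregistered ops, then report them once at the end.

    op_names: op names in visit order (may repeat).
    registered: set of op names that have a conversion rule.
    allow_missing_ops: 0 = raise, 1 = warn, 2 = stay silent.
    Returns a Counter mapping each missing op to its number of appearances.
    """
    missing = Counter(name for name in op_names if name not in registered)
    if missing:
        summary = ", ".join(f"{op} (x{n})" for op, n in sorted(missing.items()))
        if allow_missing_ops == 0:
            raise MissingOpError(f"Ops not in conversion lists: {summary}")
        if allow_missing_ops == 1:
            warnings.warn(f"Ops not in conversion lists: {summary}")
    return missing


# One missing op appearing 100 times produces a single warning: "my.custom_op (x100)".
report_missing_ops(["nn.conv2d"] + ["my.custom_op"] * 100, {"nn.conv2d"})
```

Deferring the report to the end of the pass also addresses the first problem: in mode 0 the error lists every missing op instead of stopping at the first one.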
   







[GitHub] [tvm] mbrookhart commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652090449



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }

Review comment:
       Makes me miss duck typing...
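For comparison, a duck-typed version of the `GetNewAttrs` chain is sketched below in Python. The dataclasses are hypothetical stand-ins for TVM's C++ attrs nodes; the guard mirrors `ModifyAttrsOutputDType`, which only overwrites `out_dtype` when it is a float type or unset (void).

```python
from dataclasses import dataclass, replace


# Hypothetical stand-ins for attrs nodes; real TVM attrs are C++ objects.
@dataclass(frozen=True)
class Conv2DAttrs:
    strides: tuple = (1, 1)
    out_dtype: str = ""  # "" plays the role of a void dtype


@dataclass(frozen=True)
class DenseAttrs:
    units: int = 0
    out_dtype: str = ""


@dataclass(frozen=True)
class ReshapeAttrs:
    newshape: tuple = ()


def get_new_attrs(attrs, accumulation_dtype):
    """Return a copy with out_dtype replaced when it is float or unset."""
    cur = getattr(attrs, "out_dtype", None)
    if cur is None:
        return attrs  # no out_dtype field at all: leave untouched
    if cur == "" or cur.startswith("float"):
        return replace(attrs, out_dtype=accumulation_dtype)
    return attrs  # e.g. an int32 accumulation dtype stays as-is


print(get_new_attrs(Conv2DAttrs(), "float32").out_dtype)  # float32
```

A single `getattr` probe replaces the per-type `as<...>()` checks; in C++ the equivalent would need a shared base class or a reflection-based field lookup.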

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }

Review comment:
       This might not work for all models, especially models imported from TF that don't have all of the type information in a single function. I did something I don't like in DynamicToStatic (modifying the input module) to work around that problem: https://github.com/apache/tvm/blob/e4c76232a4c27860d1604878cb39fec3ba337e10/src/relay/transforms/dynamic_to_static.cc#L226-L246
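The concern can be shown with a toy model. This is plain Python, not TVM's API: a "module" is just a dict from global function names to signatures, and `infer_call_type`, `full_module`, and `isolated_module` are hypothetical names. The sketch assumes, as the comment suggests, that the sub-expression calls a global function defined elsewhere in the original module. A fresh module built only from that expression, as `IRModule::FromExpr(expr)` does inside `GetType`, would not contain that global.

```python
from typing import Dict, Tuple

# Toy model, not TVM's API: a "module" maps global names to (argument dtypes, return dtype).
Signature = Tuple[Tuple[str, ...], str]


def infer_call_type(module: Dict[str, Signature], callee: str, arg_dtypes: Tuple[str, ...]) -> str:
    """Return the result dtype of calling a global function defined in `module`."""
    if callee not in module:
        # The isolated module never saw this global, so its type cannot be inferred.
        raise KeyError(f"global function {callee!r} is not defined in this module")
    params, ret = module[callee]
    if params != arg_dtypes:
        raise TypeError(f"{callee}: expected {params}, got {arg_dtypes}")
    return ret


# The original module: "main" calls a helper global function.
full_module: Dict[str, Signature] = {
    "main": (("float32",), "float32"),
    "helper": (("float32",), "float32"),
}

# Wrapping only the expression in a fresh module (like IRModule::FromExpr(expr)) drops "helper".
isolated_module: Dict[str, Signature] = {"main": (("float32",), "float32")}

print(infer_call_type(full_module, "helper", ("float32",)))  # float32
try:
    infer_call_type(isolated_module, "helper", ("float32",))
except KeyError as err:
    print("inference in the isolated module failed:", err)
```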

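The `pair_hash` functor in the quoted diff says it uses boost's `hash_combine` strategy, but the expression differs slightly. Boost applies the shifts to the running seed (`seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2)`), while the diff shifts `h2` and adds `h1`. Both variants are order-sensitive, so either should be adequate for the `(node, dtype)` cache key. The sketch below emulates both on 64-bit `std::size_t` values. It assumes a 64-bit `size_t`, and the function names are hypothetical.

```python
MASK64 = (1 << 64) - 1  # emulate wrap-around of a 64-bit std::size_t
GOLDEN = 0x9E3779B9     # constant used by boost::hash_combine


def combine_as_in_diff(h1: int, h2: int) -> int:
    """pair_hash as written in the diff: h1 ^ (h1 + C + (h2 << 6) + (h2 >> 2))."""
    return h1 ^ ((h1 + GOLDEN + (h2 << 6) + (h2 >> 2)) & MASK64)


def boost_hash_combine(seed: int, h: int) -> int:
    """boost::hash_combine: seed ^= h + C + (seed << 6) + (seed >> 2)."""
    return seed ^ ((h + GOLDEN + (seed << 6) + (seed >> 2)) & MASK64)


# Both are order-sensitive, so swapping the two component hashes changes the result.
print(combine_as_in_diff(1, 2) != combine_as_in_diff(2, 1))  # True
print(boost_hash_combine(1, 2) != boost_hash_combine(2, 1))  # True
```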
##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,22 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def ToMixedPrecision(
+    mixed_precision_type="float16", ignore_missing_ops=True, warn_missing_ops=True
+):
+    """
+    Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version
+    where as many operations as possible are in the target mixed_precision_type.
+
+    Note this does mutate the original graph putting it in a bad state potentially.
+
+    TODO(AndrewZhaoLuo): don't mutate the original graph.

Review comment:
       I'm not sure I see why this is modifying the original graph? You don't seem to do anything to the IRModule.
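A small model illustrates the reviewer's point. As long as the rewriter builds new nodes instead of assigning into existing ones, the input graph stays intact. The sketch below models immutable nodes with frozen dataclasses. `Node` and `to_fp16` are hypothetical names that only mimic how a Relay mutator returns fresh expressions.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical immutable graph node standing in for a Relay expression."""
    op: str
    dtype: str
    args: Tuple["Node", ...] = ()


def to_fp16(node: Node) -> Node:
    """Return a converted copy of the graph; `node` and its inputs are never modified.

    Shared sub-expressions and very deep graphs are ignored here for brevity.
    """
    new_args = tuple(to_fp16(arg) for arg in node.args)
    new_dtype = "float16" if node.dtype == "float32" else node.dtype
    # dataclasses.replace builds a new object instead of assigning to the old one.
    return replace(node, dtype=new_dtype, args=new_args)


data = Node("var", "float32")
weight = Node("const", "float32")
graph = Node("nn.conv2d", "float32", (data, weight))
converted = to_fp16(graph)
print(graph.dtype, converted.dtype)  # float32 float16
```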




[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652181818



##########
File path: tests/python/relay/test_to_mixed_precision.py
##########
@@ -0,0 +1,446 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for testing ToMixedPrecision pass"""

Review comment:
       It takes around 3 seconds on my M1 Mac.




[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652211415



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));
+
+        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!";
+        if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!";

Review comment:
       Done, it's now a single flag `missing_op_mode`
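Below is one way to fold the two flags into a single `missing_op_mode`. The encoding used here is an assumption for illustration: 0 raises an error, 1 warns and continues, and 2 ignores silently. Check the pass's docstring for the actual values. `mode_from_flags` and `handle_missing_op` are hypothetical helpers. In the old code, `ignore_missing_ops=False` made `warn_missing_ops` irrelevant because the fatal error fires first. A single mode removes that redundant combination.

```python
import logging

# Assumed encoding for illustration: 0 = error, 1 = warn and continue, 2 = ignore silently.
MISSING_OP_ERROR, MISSING_OP_WARN, MISSING_OP_IGNORE = 0, 1, 2


def mode_from_flags(ignore_missing_ops: bool, warn_missing_ops: bool) -> int:
    """Translate the old pair of boolean flags into a single missing-op mode."""
    if not ignore_missing_ops:
        return MISSING_OP_ERROR  # the fatal error fired before any warning could
    return MISSING_OP_WARN if warn_missing_ops else MISSING_OP_IGNORE


def handle_missing_op(op_name: str, missing_op_mode: int) -> None:
    """React to an op that has no registered conversion rule."""
    if missing_op_mode not in (MISSING_OP_ERROR, MISSING_OP_WARN, MISSING_OP_IGNORE):
        raise ValueError(f"invalid missing_op_mode: {missing_op_mode}")
    message = f"Op {op_name} not in conversion lists!"
    if missing_op_mode == MISSING_OP_ERROR:
        raise RuntimeError(message)
    if missing_op_mode == MISSING_OP_WARN:
        logging.warning(message)
    # MISSING_OP_IGNORE: continue silently.


handle_missing_op("my_custom_op", mode_from_flags(True, True))  # logs a warning
```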

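`CachedCast` in the quoted diff memoizes casts keyed on (source node, wanted dtype). It also stores a reverse entry, so casting the result back returns the original node instead of stacking a second cast. The Python sketch below models that caching scheme. `Expr` and `CastCache` are hypothetical stand-ins, and `id()` plays the role of the `const ExprNode*` key. The cached values keep both the source and the cast node alive, so the `id()` keys cannot be recycled while the cache exists.

```python
from typing import Dict, Optional, Tuple


class Expr:
    """Hypothetical stand-in for a Relay expression with a known dtype."""

    def __init__(self, name: str, dtype: str, src: Optional["Expr"] = None):
        self.name = name
        self.dtype = dtype
        self.src = src  # the expression this one was cast from, if any


class CastCache:
    def __init__(self) -> None:
        # (id of source expression, wanted dtype) -> cast expression
        self._cache: Dict[Tuple[int, str], Expr] = {}

    def cached_cast(self, expr: Expr, wanted: str) -> Expr:
        # Non-float tensors (e.g. integer indices) and already-matching dtypes are left alone.
        if not expr.dtype.startswith("float") or expr.dtype == wanted:
            return expr
        key = (id(expr), wanted)
        if key in self._cache:
            return self._cache[key]
        result = Expr(f"cast({expr.name}, {wanted})", wanted, src=expr)
        self._cache[key] = result
        # Reverse entry: casting `result` back to the original dtype returns `expr` itself.
        self._cache[(id(result), expr.dtype)] = expr
        return result


cache = CastCache()
x = Expr("x", "float32")
x_fp16 = cache.cached_cast(x, "float16")
```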



[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648599365



##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,20 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def AMPRewrite():

Review comment:
       I don't agree that AutoCast fails to capture the nature of this pass. For example, PyTorch uses that name for its automatic mixed precision feature: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast




[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648604259



##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };
+
+using OpStringSet = std::unordered_set<std::string>;
+
+// Default lists inspired from TF's classifications:

Review comment:
       You can refer to the design document for the layout conversion pass: https://tvm.apache.org/docs/dev/convert_layout.html. With that design, it's actually not hard to take the rules from Python.
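Following the convert_layout approach, per-op rules could live in a registry that is populated from Python. The sketch below is hypothetical throughout: `register_mixed_precision_rule`, `lookup_rule`, and the two example rules are not TVM APIs, and their dtype choices are illustrative only. Each rule returns the `[category, accumulation_dtype, output_dtype]` triple that `FTVMMixedPrecisionConversionType` describes in the diff. `follow_compute_dtype` shows one plausible reading of the FOLLOW (GRAY) category: the op stays in the low-precision type only when every floating-point input already uses it, since the op's speedup would not justify a cast.

```python
from typing import Callable, Dict, List, Optional, Sequence

# Category codes matching the MixedTypeConversionCategory enum in the diff.
MIXED_PRECISION_ALWAYS, MIXED_PRECISION_FOLLOW, MIXED_PRECISION_NEVER = 0, 1, 2

Rule = Callable[[object, str], List]  # (call, target_dtype) -> [category, accum_dtype, out_dtype]
_RULES: Dict[str, Rule] = {}  # hypothetical registry, keyed by op name


def register_mixed_precision_rule(op_name: str) -> Callable[[Rule], Rule]:
    """Decorator that records a conversion rule for `op_name`."""
    def decorator(fn: Rule) -> Rule:
        _RULES[op_name] = fn
        return fn
    return decorator


@register_mixed_precision_rule("nn.conv2d")
def _conv2d_rule(call, target_dtype):
    # Illustrative choice: compute in target_dtype, accumulate in float32.
    return [MIXED_PRECISION_ALWAYS, "float32", target_dtype]


@register_mixed_precision_rule("exp")
def _exp_rule(call, target_dtype):
    # Illustrative choice: keep a numerically sensitive op in float32.
    return [MIXED_PRECISION_NEVER, "float32", "float32"]


def lookup_rule(op_name: str, call, target_dtype: str) -> Optional[List]:
    """Return the op's descriptor, or None when no rule is registered (a "missing op")."""
    rule = _RULES.get(op_name)
    return None if rule is None else rule(call, target_dtype)


def follow_compute_dtype(arg_dtypes: Sequence[str], target_dtype: str) -> str:
    """FOLLOW ops stay in target_dtype only if every float input already is; else float32."""
    float_args = [d for d in arg_dtypes if "float" in d]
    return target_dtype if all(d == target_dtype for d in float_args) else "float32"


print(lookup_rule("nn.conv2d", None, "float16"))  # [0, 'float32', 'float16']
```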




[GitHub] [tvm] AndrewZhaoLuo edited a comment on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo edited a comment on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-858045479


   Hey folks, I've covered the simple requested changes. Here is the list of the more involved changes, along with the associated reviewers. Several of these were planned for future PRs, but it might be best to just do this correctly the first time (since it doesn't really touch other files):
   - [x] Support other floating point types out of the box (e.g. bfloat16)
   - [ ] Naming of things (pass, GREEN/RED/GRAY, etc.)
   - [ ] Python interface for Coloring/Accumulation logic
   - [ ] How to register ops for coloring
   - [x] MixedModeMutator to avoid stack overflow
   
   Let me know if I missed anything
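The stack-overflow item refers to deep graphs exhausting the call stack under a recursive visitor. The Python sketch below shows the general idea. It performs a bottom-up rewrite driven by an explicit stack and keeps a memo table so that shared sub-expressions are rewritten only once. It follows the spirit of `MixedModeMutator` but is not its actual implementation. `Node` and `post_order_rewrite` are hypothetical names, and the graph is assumed to be acyclic.

```python
from typing import Callable, Dict, List, Tuple


class Node:
    """Hypothetical graph node standing in for a Relay expression."""

    def __init__(self, op: str, args: Tuple["Node", ...] = ()):
        self.op = op
        self.args = args


def post_order_rewrite(root: Node, rewrite: Callable[[Node, Tuple[Node, ...]], Node]) -> Node:
    """Rewrite an acyclic graph bottom-up using an explicit stack instead of recursion."""
    memo: Dict[int, Node] = {}  # id(original node) -> rewritten node
    stack: List[Tuple[Node, bool]] = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if id(node) in memo:
            continue  # shared sub-expression already rewritten
        if children_done:
            # All inputs have been rewritten; build the new node from them.
            new_args = tuple(memo[id(arg)] for arg in node.args)
            memo[id(node)] = rewrite(node, new_args)
        else:
            # Revisit this node after its inputs, which are pushed on top of it.
            stack.append((node, True))
            for arg in node.args:
                if id(arg) not in memo:
                    stack.append((arg, False))
    return memo[id(root)]


# Usage: a chain far deeper than CPython's default recursion limit (usually 1000).
depth = 100_000
graph = Node("input")
for _ in range(depth):
    graph = Node("relu", (graph,))

new_root = post_order_rewrite(graph, lambda node, args: Node(node.op.upper(), args))

# Walk the rewritten chain iteratively to check it.
rewritten_depth, cur = 0, new_root
while cur.args:
    cur = cur.args[0]
    rewritten_depth += 1
```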


[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652190815



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));
+
+        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!";
+        if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!";
+
+        // If not registered, by default assume is a generic FOLLOW operation.
+        initial_category = MIXED_PRECISION_FOLLOW;
+        accumulation_dtype = DataType::Float(16);
+        output_dtype = DataType::Float(16);

Review comment:
       Yes oops!
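The commented line is the fallback for ops with no registered FTVMMixedPrecisionConversionType. There, both dtypes are hardcoded to `DataType::Float(16)`, so a pass configured for, say, bfloat16 would still emit float16 for unregistered ops. That appears to be what the reply acknowledges. Below is a minimal standalone sketch of the presumably intended behavior. It assumes dtypes are plain strings rather than `tvm::DataType`, and `DefaultInfoForMissingOp` is a hypothetical helper, not part of the PR:

```cpp
#include <cassert>
#include <string>

// Mirrors the enum in to_mixed_precision.cc.
enum MixedTypeConversionCategory : int {
  MIXED_PRECISION_ALWAYS = 0,
  MIXED_PRECISION_FOLLOW = 1,
  MIXED_PRECISION_NEVER = 2
};

// Stand-in for the [category, accumulation dtype, output dtype] triple that a
// registered FTVMMixedPrecisionConversionType returns. Dtypes are plain strings
// here (an assumption; the pass itself uses tvm::DataType).
struct OpConversionInfo {
  MixedTypeConversionCategory category;
  std::string accumulation_dtype;
  std::string output_dtype;
};

// Hypothetical helper: defaults for an op with no registered conversion rule.
// The op is treated as FOLLOW, and both dtypes come from the configured mixed
// precision type instead of a hardcoded "float16".
OpConversionInfo DefaultInfoForMissingOp(const std::string& mixed_precision_type) {
  assert(!mixed_precision_type.empty());
  return {MIXED_PRECISION_FOLLOW, mixed_precision_type, mixed_precision_type};
}
```

With this helper, `DefaultInfoForMissingOp("bfloat16")` yields bfloat16 for both dtypes. In the pass itself, the equivalent change would presumably assign `mixed_precision_type` where the hunk uses `DataType::Float(16)`.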







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652191690



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));
+
+        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!";
+        if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!";

Review comment:
       So you suggest a single flag `verbose` which just controls emission of warnings?
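The question is whether `ignore_missing_ops` and `warn_missing_ops` should collapse into a single `verbose` flag. The sketch below shows one way that could look, under these assumptions: unregistered ops are always tolerated (treated as FOLLOW), each encounter is counted per op name, and one summary is produced only when `verbose` is set, instead of one warning per call site. `MissingOpTracker` is a hypothetical name, not part of the PR. A later revision of the file does keep a `missing_ops` count map for emitting warnings.

```cpp
#include <cassert>
#include <map>
#include <sstream>
#include <string>

// Hypothetical single-flag design: missing ops never abort the pass. They are
// counted, and `verbose` only controls whether a summary is reported.
class MissingOpTracker {
 public:
  explicit MissingOpTracker(bool verbose) : verbose_(verbose) {}

  // Record one call to an op that has no registered conversion rule.
  void Record(const std::string& op_name) {
    assert(!op_name.empty());
    ++missing_ops_[op_name];
  }

  // Number of times `op_name` was recorded (0 if never seen).
  int Count(const std::string& op_name) const {
    auto it = missing_ops_.find(op_name);
    return it == missing_ops_.end() ? 0 : it->second;
  }

  // One line per distinct op, sorted by name (std::map keeps keys ordered).
  // Returns an empty string when not verbose or when nothing was missing.
  std::string Report() const {
    if (!verbose_ || missing_ops_.empty()) return "";
    std::ostringstream os;
    os << "Ops not in conversion lists (treated as FOLLOW):\n";
    for (const auto& entry : missing_ops_) {
      os << "  " << entry.first << ": " << entry.second << "\n";
    }
    return os.str();
  }

 private:
  bool verbose_;
  std::map<std::string, int> missing_ops_;
};
```

A caller would log `Report()` once after the mutator finishes, for example through `LOG(WARNING)` as the hunk already does. This keeps large models from producing one warning per call to the same op.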







[GitHub] [tvm] masahi edited a comment on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-859341545









[GitHub] [tvm] AndrewZhaoLuo edited a comment on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo edited a comment on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-858045479


   Hey folks, I've covered the simple changes requested. Here is the list of more involved changes along with the associated reviewer. Several of these changes were planned as future PRs, but it might be best to get this right the first time (since it doesn't really touch other files):
   - [ ] Support other floating point types out of the box (e.g. bfloat16)
   - [ ] Naming of things (pass, GREEN/RED/GRAY, etc.)
   - [ ] Python interface for Coloring/Accumulation logic
   - [ ] How to register ops for coloring
   - [ ] MixedModeMutator to avoid stack overflow
   
   Let me know if I missed anything
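On the naming item: the GREEN/RED/GRAY colors presumably correspond to the ALWAYS/NEVER/FOLLOW values of `MixedTypeConversionCategory` in the diff. The enum comment still refers to "MIXED_PRECISION_NEVER colored ops". The sketch below shows how a FOLLOW op might be resolved. It assumes, without confirmation from the hunks, that a FOLLOW op runs in the mixed precision type only when every floating point argument already has that type, and otherwise falls back to float32. Dtypes are plain strings, and `ResolveCategory` and `ArgCastDType` are hypothetical names.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Mirrors the enum in to_mixed_precision.cc.
enum MixedTypeConversionCategory : int {
  MIXED_PRECISION_ALWAYS = 0,  // formerly "GREEN" (assumed)
  MIXED_PRECISION_FOLLOW = 1,  // formerly "GRAY" (assumed)
  MIXED_PRECISION_NEVER = 2    // formerly "RED" (assumed)
};

// Stand-in for a tensor argument's dtype, e.g. "float32", "float16", "int64".
using DType = std::string;

// True for floating point dtype strings such as "float16", "float32", "bfloat16".
bool IsFloat(const DType& t) { return t.rfind("float", 0) == 0 || t == "bfloat16"; }

// Assumed rule: FOLLOW becomes ALWAYS if all float args are already in the mixed
// precision type (non-float args such as int64 indices are ignored), else NEVER.
MixedTypeConversionCategory ResolveCategory(MixedTypeConversionCategory initial,
                                            const std::vector<DType>& arg_dtypes,
                                            const DType& mixed_precision_type) {
  if (initial != MIXED_PRECISION_FOLLOW) return initial;
  for (const DType& t : arg_dtypes) {
    if (IsFloat(t) && t != mixed_precision_type) return MIXED_PRECISION_NEVER;
  }
  return MIXED_PRECISION_ALWAYS;
}

// Dtype that floating point arguments are cast to once the category is resolved.
DType ArgCastDType(MixedTypeConversionCategory resolved, const DType& mixed_precision_type) {
  assert(resolved != MIXED_PRECISION_FOLLOW);  // FOLLOW must be resolved first
  return resolved == MIXED_PRECISION_ALWAYS ? mixed_precision_type : DType("float32");
}
```

Under this rule, a FOLLOW op never introduces a cast of its own. It stays in lower precision only when its inputs already are, which matches the enum comment that FOLLOW ops "don't have speedups to justify a cast".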





[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652216408



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,420 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache_;
+
+  /*! \brief The target datatype we want to convert to e.g. FP16 */
+  const DataType mixed_precision_type;

Review comment:
       Please check through all changes you've made.
   ```suggestion
     const DataType mixed_precision_type_;
   ```

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,420 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache_;

Review comment:
       docstring?
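For reference, `cast_nodes_cache_` maps a (source node, wanted dtype) pair to the cast node already created for it. `CachedCast` also stores a reverse entry, so casting that cast back to the original dtype returns the original node instead of stacking a second cast. Below is a minimal standalone sketch of the scheme. It assumes integer node ids instead of `const ExprNode*` and string dtypes instead of `DataType`, and `CastCache` and `PairHash` are hypothetical names. One detail: the classic `boost::hash_combine` computes `seed ^ (hash(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2))`, shifting the running seed, whereas `pair_hash` in the diff shifts `h2`. The sketch follows boost's form.

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Boost-style hash_combine for std::pair keys: the shifts apply to h1 (the seed).
struct PairHash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& p) const {
    std::size_t h1 = std::hash<T1>()(p.first);
    std::size_t h2 = std::hash<T2>()(p.second);
    return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
  }
};

// Hypothetical stand-ins: an int node id for const ExprNode*, a string for DataType.
using NodeId = int;
using CastKey = std::pair<NodeId, std::string>;

class CastCache {
 public:
  // Returns a node holding `node` cast to `wanted`, creating one only on a cache miss.
  NodeId Cast(NodeId node, const std::string& node_dtype, const std::string& wanted) {
    if (node_dtype == wanted) return node;  // no cast needed
    auto it = cache_.find({node, wanted});
    if (it != cache_.end()) return it->second;  // reuse the existing cast
    NodeId result = next_id_++;                 // pretend a new cast node was built
    cache_[{node, wanted}] = result;
    // Reverse entry: casting the new node back to node_dtype yields the original node.
    cache_[{result, node_dtype}] = node;
    return result;
  }

  // How many cast nodes have been created so far.
  int NumCastsCreated() const { return next_id_ - kFirstCastId; }

 private:
  static constexpr NodeId kFirstCastId = 1000;
  NodeId next_id_ = kFirstCastId;
  std::unordered_map<CastKey, NodeId, PairHash> cache_;
};
```

A docstring for the member might then read: `/*! \brief Map of (node, wanted dtype) to that node cast to the wanted dtype; reverse casts map back to the original node. */`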

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,420 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// Maps a (parent node, wanted dtype) pair to an existing node cast to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// The arguments are the call node and the target mixed precision dtype as a string
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache_;
+
+  /*! \brief The target datatype we want to convert to e.g. FP16 */
+  const DataType mixed_precision_type;
+
+  // Map of Ops with no associated FTVMMixedPrecisionConversionType to the times they were
+  // encountered. Used for emitting warnings on missing ops in the pass.
+  std::unordered_map<std::string, int> missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (ignore_non_float && !(tensor_type->dtype).is_float()) ||
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    CHECK(expr_node) << "Non-expression node found in cast: " << expr;
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache_.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache_.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache_[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache_[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    }
+    CHECK(0) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+    return expr;
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16))
+      : MixedModeMutator(), mixed_precision_type(mixed_precision_type) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16()) {
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got "
+                 << mixed_precision_type;
+    }
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    CHECK(post_call_node) << "Expected a CallNode, but got " << post;
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));
+        ICHECK(op_descriptor.size() == 3)
+            << "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size()
+            << ") from FTVMMixedPrecisionConversionType for " << AsText(op, false);
+
+        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        missing_ops[op->name] += 1;
+
+        // If not registered, assume by default that it is a generic FOLLOW operation.
+        initial_category = MIXED_PRECISION_FOLLOW;
+        accumulation_dtype = mixed_precision_type;
+        output_dtype = mixed_precision_type;
+      }
+    } else {
+      LOG(FATAL) << "Unsupported op type in CallNode: " << pre_call_node->op;
+    }
+
+    // First check if all the new mutated args are in lower precision form
+    Array<Type> cur_arg_types;
+    bool all_args_mixed_type_compatible = true;
+    for (Expr arg : post_call_node->args) {
+      Type cur_arg_type = GetType(arg);
+      cur_arg_types.push_back(cur_arg_type);
+
+      if (initial_category == MIXED_PRECISION_FOLLOW && all_args_mixed_type_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_mixed_type_compatible = IsMixedPrecisionType(cur_arg_type, true) ||
+                                        arg->IsInstance<VarNode>() ||
+                                        arg->IsInstance<ConstantNode>();
+        all_args_mixed_type_compatible &= is_mixed_type_compatible;
+      }
+    }
+
+    // Determine the final category we want for conversion
+    MixedTypeConversionCategory final_category = initial_category;
+    if (initial_category == MIXED_PRECISION_FOLLOW) {
+      final_category =
+          all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER;
+    }
+
+    // Create the new arguments to the call.
+    DataType wanted_arg_dtypes =
+        final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type : DataType::Float(32);
+    auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes);
+    Array<Expr> new_args = call_args_and_types.first;
+    Array<Type> new_arg_types;
+
+    if (pre_call_node->op.as<FunctionNode>()) {
+      // Calls to FunctionNodes don't store type info in the Call; type_args should be empty ([])
+      new_arg_types = pre_call_node->type_args;
+    } else {
+      new_arg_types = call_args_and_types.second;
+    }
+
+    // Finally create the new attributes.
+    if (final_category == MIXED_PRECISION_ALWAYS) {
+      Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype);
+      Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span);
+      if (accumulation_dtype != output_dtype) {
+        output = CastArg(output, GetType(output), output_dtype);
+      }
+      return output;
+    }
+
+    return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
+  }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    // Erase the ret_type annotation and let the normal pass recalculate
+    const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
+    return ExprMutator::VisitExpr_(func);
+  }
+
+  Expr VisitExpr_(const LetNode* op) final {
+    // First convert as much of the bound computation to lower precision as possible
+    Expr value = this->Mutate(op->value);
+
+    // Then rewrite the var type and associated expression
+    Var var = Downcast<Var>(this->Mutate(op->var));
+    VarNode* mutable_var = const_cast<VarNode*>((op->var).as<VarNode>());
+    mutable_var->type_annotation = GetType(value);
+    mutable_var->checked_type_ = mutable_var->type_annotation;
+
+    // Mutate body last as it may depend on previous results
+    Expr body = this->Mutate(op->body);
+    return Let(var, value, body, op->span);
+  }
+
+  // To access map of ops not registered for error reporting
+  friend Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type,
+                               int missing_op_mode);
+};
+
+Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type, int missing_op_mode) {
+  /*
+  missing_op_mode:
+
+  0: Does not allow any missing ops. Will throw errors and terminate the pass when encountering any.
+  1: Allow missing ops but throw warnings.
+  2: Allow missing ops and silently ignore them.
+  */
+  ICHECK(missing_op_mode >= 0 && missing_op_mode <= 2)
+      << "missing_op_mode must be either 0, 1, or 2, got " << missing_op_mode;
+
+  MixedPrecisionPass converter = MixedPrecisionPass(mixed_precision_type);
+  auto result = converter.Mutate(expr);
+
+  for (auto it = converter.missing_ops.begin();
+       missing_op_mode != 2 && it != converter.missing_ops.end(); it++) {
+    std::string op_name = it->first;
+    int appear_count = it->second;
+
+    LOG(WARNING) << "Op \"" << op_name << "\" not registered with "
+                 << "FTVMMixedPrecisionConversionType; appears " << appear_count
+                 << " times in graph.";
+  }
+
+  if (converter.missing_ops.size() != 0 && missing_op_mode == 0) {
+    CHECK(0) << "Missing ops were found, please fix!";

Review comment:
       I don't think end users are expected to fix this. We should direct them to the discuss forum to report what is missing or, in the future, link to a tutorial for the AMP pass to help them fix it. cc @areusch 
   ```suggestion
       CHECK(0) << "Missing ops were found";
   ```
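The following is a minimal Python sketch of how the `missing_op_mode` reporting in `ToMixedPrecision` behaves, with the final error reworded along the lines of this suggestion. `report_missing_ops` is a hypothetical helper, and the forum wording is an assumption, not the pass's actual message.

```python
import logging
from collections import Counter
from typing import Iterable


def report_missing_ops(missing_op_names: Iterable[str], missing_op_mode: int) -> Counter:
    """Count ops with no conversion rule and react according to missing_op_mode.

    0: warn, then raise; 1: warn only; 2: stay silent.
    """
    if missing_op_mode not in (0, 1, 2):
        raise ValueError(f"missing_op_mode must be 0, 1, or 2, got {missing_op_mode}")
    counts = Counter(missing_op_names)
    if missing_op_mode != 2:
        for name, count in counts.items():
            logging.warning(
                'Op "%s" not registered with FTVMMixedPrecisionConversionType; '
                "appears %d times in graph.",
                name,
                count,
            )
    if counts and missing_op_mode == 0:
        # Per the review: point users at the community instead of asking them to patch the pass.
        raise RuntimeError("Missing ops were found; please report them on the TVM discuss forum.")
    return counts
```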

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,420 @@
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache_;
+
+  /*! \brief The target datatype we want to convert to e.g. FP16 */
+  const DataType mixed_precision_type;
+
+  // Map of Ops with no associated FTVMMixedPrecisionConversionType to the times they were
+  // encountered. Used for emitting warnings on missing ops in the pass.
+  std::unordered_map<std::string, int> missing_ops;

Review comment:
       ```suggestion
     /*! \brief Map of Ops with no associated FTVMMixedPrecisionConversionType to the times they were
      * encountered. Used for emitting warnings on missing ops in the pass.
      */
     std::unordered_map<std::string, int> missing_ops_;
   ```
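The `cast_nodes_cache_` member in this hunk backs `CachedCast` earlier in the file. The sketch below models that caching in Python under two assumptions: object identity (`id`) plays the role of the `const ExprNode*` key, and a tuple stands in for Relay's `Cast`. `CastCache` is a hypothetical name. The sketch shows the reverse entry, which lets a cast back to the original dtype return the original node instead of stacking a second cast.

```python
from typing import Dict, List, Tuple


class CastCache:
    """Memoize casts keyed by (node identity, wanted dtype)."""

    def __init__(self) -> None:
        self._cache: Dict[Tuple[int, str], object] = {}
        # Keep nodes alive so their id() values cannot be reused by new objects.
        self._keep_alive: List[object] = []

    def cast(self, node: object, node_dtype: str, wanted_dtype: str) -> object:
        # Non-float inputs (e.g. int32) and no-op casts are returned untouched.
        if not node_dtype.startswith("float") or node_dtype == wanted_dtype:
            return node
        key = (id(node), wanted_dtype)
        if key in self._cache:
            return self._cache[key]
        result = ("cast", node, wanted_dtype)  # Stand-in for Cast(expr, wanted_dtype).
        self._cache[key] = result
        # Reverse entry: casting the result back to node_dtype yields the original node.
        self._cache[(id(result), node_dtype)] = node
        self._keep_alive += [node, result]
        return result
```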




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] masahi commented on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-864694497


   Thanks @AndrewZhaoLuo for the great work, and everyone for reviews!!
   
   I'll follow up with CUDA and OpenCL support.





[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r653814924



##########
File path: python/tvm/relay/transform/mixed_precision.py
##########
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=line-too-long,unused-argument
+"""Default behavior for ops in mixed_precision pass. Import this file to use."""
+from typing import List
+
+from tvm import relay
+from tvm.relay.op import register_mixed_precision_conversion
+
+# MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+# savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+# justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+# numerical reasons.
+MIXED_PRECISION_ALWAYS = 0
+MIXED_PRECISION_FOLLOW = 1
+MIXED_PRECISION_NEVER = 2
+
+# Default lists inspired by TF's classifications:
+# github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+# They are biased toward NVIDIA Tensor Cores, so modify the lists for your hardware.
+DEFAULT_ALWAYS_LIST = [
+    "nn.conv1d",
+    "nn.conv2d",
+    "nn.conv3d",
+    "nn.conv1d_transpose",
+    "nn.conv2d_transpose",
+    "nn.conv3d_transpose",
+    "nn.dense",
+    # "nn.batch_matmul", # Handled by a special case
+]
+DEFAULT_FOLLOW_LIST = [
+    # These ops add new data or change shape
+    "nn.pad",
+    "nn.batch_flatten",
+    "concatenate",
+    "zeros",
+    "split",
+    "squeeze",
+    "transpose",
+    "expand_dims",
+    "reshape",
+    "dyn.reshape",
+    "broadcast_to_like",
+    "dyn.broadcast_to",
+    "strided_slice",
+    "dyn.strided_slice",
+    "take",
+    "argwhere",
+    "where",
+    "tile",
+    "dyn.tile",
+    "scatter",
+    "full",
+    "dyn.full",
+    # Comparison
+    "less",
+    "greater",
+    "less_equal",
+    "greater_equal",
+    # By definition copy and cast will depend on inputs for output.
+    "copy",
+    "cast",
+    "cast_like",
+    # Simple arithmetic
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "nn.bias_add",
+    "nn.batch_norm",
+    "sum",
+    "mean",
+    "sqrt",
+    "shape_of",
+    # Simple activations
+    "max",
+    "min",
+    "maximum",
+    "minimum",
+    "nn.relu",
+    "nn.leaky_relu",
+    "nn.prelu",
+    "nn.dropout",
+    # Complicated activations which saturate in a narrow range
+    "sigmoid",
+    "tanh",
+    # Pooling operations
+    "nn.max_pool1d",
+    "nn.max_pool2d",
+    "nn.max_pool3d",
+    "nn.avg_pool1d",
+    "nn.avg_pool2d",
+    "nn.avg_pool3d",
+    # "nn.global_max_pool1d", # does not exist yet
+    "nn.global_max_pool2d",
+    # "nn.global_max_pool3d", # does not exist yet
+    # "nn.global_avg_pool1d", # does not exist yet
+    "nn.global_avg_pool2d",
+    # "nn.global_avg_pool3d", # does not exist yet
+    "nn.adaptive_max_pool1d",
+    "nn.adaptive_max_pool2d",
+    "nn.adaptive_max_pool3d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+]
+DEFAULT_NEVER_LIST = [
+    # In general if |f(x)| >> |x| for expected inputs then put the op here.
+    "exp",
+    "power",
+    "nn.cross_entropy",
+    "nn.cross_entropy_with_logits",
+    "nn.softmax",
+    "nn.l2_normalize",
+    # The error function does not appear to lower to an fp16 version in LLVM.
+    # Move it to the follow list when it does.
+    "erf",
+]
+
+
+# Returns a decorator that registers the decorated function under FTVMMixedPrecisionConversionType for every given op
+def register_func_to_op_list(list_ops):
+    def decorator(func):
+        for op_name in list_ops:
+            register_mixed_precision_conversion(op_name, func=func)
+        return func
+
+    return decorator
+
+
+def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]:
+    # Assume support accumulation dtypes <---> has out_dtype attr

Review comment:
       Why this assumption?
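For reference while reading this hunk, here is a self-contained sketch of the `register_func_to_op_list` pattern. It assumes a plain dictionary as a stand-in for `register_mixed_precision_conversion`; `CONVERSION_REGISTRY` and `always_lower_precision` are hypothetical names. Unlike the decorator in the diff, this version returns `func`, so the decorated name stays bound to the function rather than `None`. The three-element return value follows the `[category, accumulation dtype, output dtype]` layout that the C++ side checks for.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for the op attribute registry.
CONVERSION_REGISTRY: Dict[str, Callable] = {}


def register_func_to_op_list(list_ops: List[str]) -> Callable[[Callable], Callable]:
    def decorator(func: Callable) -> Callable:
        for op_name in list_ops:
            CONVERSION_REGISTRY[op_name] = func
        return func  # Keep the decorated name usable after registration.

    return decorator


@register_func_to_op_list(["nn.conv2d", "nn.dense"])
def always_lower_precision(call_node, mixed_precision_type: str) -> List:
    # [MIXED_PRECISION_ALWAYS, accumulation dtype, output dtype]
    return [0, "float32", mixed_precision_type]
```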

##########
File path: python/tvm/relay/transform/mixed_precision.py
##########
@@ -0,0 +1,177 @@
+def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]:
+    # Assume support accumulation dtypes <---> has out_dtype attr

Review comment:
       Found it. Please port the comments here to provide the full context for people who may want to resolve this assumption in the future.
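The body of `get_generic_out_dtypes` is not shown in this excerpt. The version below is a hypothetical illustration of the assumption under discussion, with its rationale written out as a comment in the spirit of this request. `SimpleNamespace` stands in for Relay call and attrs objects.

```python
from types import SimpleNamespace
from typing import List


def get_generic_out_dtypes(call_node, mixed_precision_type: str) -> List[str]:
    """Return [accumulation_dtype, output_dtype] for a call (sketch)."""
    # Assumption: an op supports accumulating in a wider dtype if and only if
    # its attrs expose an out_dtype field, because out_dtype is the attribute
    # the C++ pass rewrites (see ModifyAttrsOutputDType).
    attrs = getattr(call_node, "attrs", None)
    if attrs is not None and hasattr(attrs, "out_dtype"):
        return ["float32", mixed_precision_type]
    return [mixed_precision_type, mixed_precision_type]


conv_call = SimpleNamespace(attrs=SimpleNamespace(out_dtype=""))
print(get_generic_out_dtypes(conv_call, "float16"))  # ['float32', 'float16']
```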







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648596936



##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,20 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def AMPRewrite():

Review comment:
       I disagree. All the passes are named with verbs that describe what they do, while `AMP` is a noun. Maybe `AutoCast` would be better, but it doesn't capture the mixed precision nature.
   
   Maybe `ToMixedPrecision` would be a better name?







[GitHub] [tvm] AndrewZhaoLuo commented on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-859715742


   > @AndrewZhaoLuo oh by M1 do you mean its cpu or gpu (metal)?
   
   I think it's CPU. Here's the benchmarking + tuning script I used: https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/a3c4b6b2235afb1826b237af1136bbb9539c9ff9/fp16_pass/benchmark_m1_mac_fp16.py
   
   The other models you have are interesting. I think the SSD model I used has combined NMS; at least, it returns variable-length tensors representing different numbers of detected objects.





[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r651430536



##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };

Review comment:
       I've implemented the suggestions listed.
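For context on the `FP16OpDType` struct quoted above (separate `accumulation_dtype` and `output_dtype`): a rough, pure-Python illustration of why long reductions such as `nn.dense` or `nn.conv2d` are usually accumulated in FP32 even when their inputs are FP16. It emulates IEEE half and single rounding with the standard `struct` module's `'e'` and `'f'` formats; it is not code from the pass, and the helper names are hypothetical.

```python
import struct

def round_to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def round_to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def accumulate(values, round_acc):
    """Sum `values`, rounding the running total after every addition."""
    acc = 0.0
    for v in values:
        acc = round_acc(acc + v)
    return acc

# 10,000 copies of 0.1, already stored in FP16 (as FP16 inputs would be).
inputs = [round_to_fp16(0.1)] * 10_000
exact = sum(inputs)  # computed in float64, about 999.76

fp16_sum = accumulate(inputs, round_to_fp16)  # FP16 accumulator
fp32_sum = accumulate(inputs, round_to_fp32)  # FP32 accumulator
print(exact, fp16_sum, fp32_sum)
```

With the FP16 accumulator the running total stalls at 256: there the spacing between representable halves is 0.25, so adding roughly 0.1 rounds back to the same value. The FP32 accumulator stays close to the exact sum, which is the motivation for letting an op accumulate in a wider type than it outputs.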







[GitHub] [tvm] anijain2305 commented on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-857315667


   TF SSD is good enough. Thanks @AndrewZhaoLuo 





[GitHub] [tvm] echuraev commented on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
echuraev commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-860009739


   > So you've tested only on LLVM? Does this work on `metal` target? Not sure if our metal backend supports fp16 or if M1 GPU is good at fp16 in general @echuraev
   
   The Metal backend supports fp16. As far as I know, @elvin-n has run fp16 models with our Metal backend and collected some performance metrics; I think he'll add some information about it.
   
   As for M1, we haven't tried running fp16 models on Metal on M1 yet. Theoretically it should work, but we need to check.
   
   





[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652035337



##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };
+
+using OpStringSet = std::unordered_set<std::string>;
+
+// Default lists inspired from TF's classifications:

Review comment:
       This is now done.

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into fp16 form.
+ */
+#include "to_mixed_precision.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<MixedTypeConversionCategory(const CallNode*)>;
+
+// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<MixedPrecisionOpOutDType(const CallNode*)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+  const DataType mixed_precision_type;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);

Review comment:
       Done
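The `CachedCastNodes` map quoted above keys on a (parent node, wanted dtype) pair so that several consumers of one tensor share a single cast; that pair key is why the C++ side needs the custom `pair_hash`. A minimal Python sketch of the same caching idea follows. `Expr`, `Cast` and `CastCache` are hypothetical stand-ins, not TVM classes, and Python tuples are hashable, so no custom hasher is needed.

```python
from typing import Dict, Tuple

class Expr:
    """Hypothetical stand-in for a Relay expression with a known dtype."""
    def __init__(self, name: str, dtype: str):
        self.name = name
        self.dtype = dtype

class Cast(Expr):
    """Hypothetical cast node; it keeps a reference to its parent."""
    def __init__(self, parent: Expr, dtype: str):
        super().__init__(f"cast({parent.name}, {dtype})", dtype)
        self.parent = parent

class CastCache:
    """Create at most one cast per (parent node, wanted dtype) pair."""
    def __init__(self) -> None:
        # Keyed on object identity, like the C++ map keyed on `const ExprNode*`.
        # Each cached Cast keeps its parent alive, so that id() cannot be
        # reused by another object while the entry exists.
        self._cache: Dict[Tuple[int, str], Expr] = {}

    def cast(self, expr: Expr, wanted_dtype: str) -> Expr:
        if expr.dtype == wanted_dtype:
            return expr  # already the right dtype, no cast needed
        key = (id(expr), wanted_dtype)
        if key not in self._cache:
            self._cache[key] = Cast(expr, wanted_dtype)
        return self._cache[key]

x = Expr("x", "float32")
cache = CastCache()
a = cache.cast(x, "float16")
b = cache.cast(x, "float16")  # a second consumer reuses the same cast node
print(a is b, a.name)  # True cast(x, float16)
```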







[GitHub] [tvm] Lunderberg commented on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
Lunderberg commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-864221346


   > Later I can test it on CUDA (tensorcore) and OpenCL (intel), and hopefully @Lunderberg for vulkan.
   
   Currently, I can run all the tests in `test_to_mixed_precision.py` with the LLVM target/device, but both cuda and vulkan backends throw an exception at `TVMFuncCall` in `c_runtime_api.cc` if I edit the `run_module` function to use a different target.
   
   On the cuda side, it's failing a check that requires 16-bit floats to be used in pairs.
   
   ```
   Check failed: lanes % 2 == 0 (1 vs. 0) : only support even lane for half type
   ```
   
   On the vulkan side, it's something similar: a SPIR-V validation check fails on an alignment rule.
   
   ```
   Check failed: res == SPV_SUCCESS (-10 vs. 0) :  index=27 error:Structure id 12 decorated as Block for variable in StorageBuffer storage class must follow standard storage buffer layout rules: member 0 contains an array with stride 6 not satisfying alignment to 8
   %_struct_12 = OpTypeStruct %_runtimearr_v3half  
   ```
   
   I don't think either of these is a reason not to merge, and I've added the vulkan errors to my todo list for the ongoing `float16` work.
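Assuming the std430-style base-alignment rules that Vulkan's standard storage buffer layout applies (a scalar of N bytes aligns to N, a two-component vector to 2N, and a three- or four-component vector to 4N), the numbers in the Vulkan error check out: a tightly packed `v3half` element is 6 bytes, but its base alignment is 8. The helper names below are hypothetical; this is arithmetic on the layout rule, not TVM code.

```python
def vector_base_alignment(component_bytes: int, components: int) -> int:
    """Base alignment of a scalar/vector under std430-style rules."""
    if components == 1:
        return component_bytes
    if components == 2:
        return 2 * component_bytes
    if components in (3, 4):
        return 4 * component_bytes
    raise ValueError("vectors have 1 to 4 components")

def packed_stride(component_bytes: int, components: int) -> int:
    """Array stride if elements are packed with no padding."""
    return component_bytes * components

def padded_stride(component_bytes: int, components: int) -> int:
    """Smallest stride that is a multiple of the element's base alignment."""
    align = vector_base_alignment(component_bytes, components)
    size = packed_stride(component_bytes, components)
    return -(-size // align) * align  # round size up to a multiple of align

HALF = 2  # bytes per float16 component
stride = packed_stride(HALF, 3)            # v3half, tightly packed
align = vector_base_alignment(HALF, 3)
print(stride, align, stride % align == 0)  # 6 8 False: rejected by the validator
print(padded_stride(HALF, 3))              # 8: a stride the rule would accept
```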





[GitHub] [tvm] AndrewZhaoLuo edited a comment on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo edited a comment on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-858045479


   Hey folks, I've covered the simple requested changes. Here is the list of more involved changes along with the associated reviewer. Several of these changes were planned for future PRs, but it might be best to just get this right the first time (since it doesn't really touch other files):
   - [ ] Support other floating point types out of the box (e.g. bfloat16)
   - [ ] Naming of things (pass, GREEN/RED/GRAY, etc.)
   - [ ] Python interface for Coloring/Accumulation logic
   - [ ] How to register ops for coloring
   - [x] MixedModeMutator to avoid stack overflow
   
   Let me know if I missed anything.
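On the MixedModeMutator item: the point is that a naive recursive visitor can overflow the stack on very long dataflow chains. The sketch below only illustrates the general technique (post-order rewriting with an explicit stack and a memo table) on a hypothetical `Node` type; it is not TVM's implementation.

```python
from typing import Callable, Dict, List, Tuple

class Node:
    """Hypothetical dataflow node: an op name plus argument nodes."""
    def __init__(self, op: str, args: Tuple["Node", ...] = ()):
        self.op = op
        self.args = args

def rewrite_post_order(root: Node, rewrite: Callable[[Node, List[Node]], Node]) -> Node:
    """Rewrite each node after all of its arguments, using an explicit stack
    instead of recursion so very deep graphs cannot exhaust the call stack.
    Shared subexpressions are rewritten once thanks to the memo table."""
    memo: Dict[int, Node] = {}
    stack: List[Tuple[Node, bool]] = [(root, False)]
    while stack:
        node, args_done = stack.pop()
        if id(node) in memo:
            continue  # already rewritten via another path
        if args_done:
            new_args = [memo[id(arg)] for arg in node.args]
            memo[id(node)] = rewrite(node, new_args)
        else:
            stack.append((node, True))  # revisit once the arguments are done
            for arg in node.args:
                if id(arg) not in memo:
                    stack.append((arg, False))
    return memo[id(root)]

# A 10,000-deep chain would exceed CPython's default recursion limit (1000)
# if visited with a naive recursive function.
node = Node("x")
for _ in range(10_000):
    node = Node("nn.relu", (node,))
out = rewrite_post_order(node, lambda n, args: Node(n.op, tuple(args)))
```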





[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r649466687



##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,20 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def AMPRewrite():

Review comment:
       I'm fine if no others complain about this naming.







[GitHub] [tvm] AndrewZhaoLuo edited a comment on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo edited a comment on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-865193717


   Added some tracking issues for CUDA and Vulkan:
   https://github.com/apache/tvm/issues/8295 
   https://github.com/apache/tvm/issues/8294
   
   + main tracking issue: https://github.com/apache/tvm/issues/8296





[GitHub] [tvm] anijain2305 commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r646843324



##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };

Review comment:
       What happens if there is an op that is not associated with any of the colors? Is the default RED?

##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };
+
+using OpStringSet = std::unordered_set<std::string>;
+
+// Default lists inspired from TF's classifications:
+// github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+// They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
+OpStringSet DEFAULT_GREEN_LIST({
+    "nn.conv1d",
+    "nn.conv2d",
+    "nn.conv3d",
+    "nn.conv1d_transpose",
+    "nn.conv2d_transpose",
+    "nn.conv3d_transpose",
+    "nn.dense",
+    "nn.batch_matmul",
+});
+OpStringSet DEFAULT_GRAY_LIST({
+    // These ops add new data or change shape
+    "nn.pad",
+    "nn.batch_flatten",
+    "concatenate",
+    "zeros",
+    "split",
+    "squeeze",
+    "transpose",
+    "expand_dims",
+    "reshape",
+    "dyn.reshape",
+    "broadcast_to_like",
+    "dyn.broadcast_to",
+    "strided_slice",
+    "dyn.strided_slice",
+    "take",
+    "argwhere",
+    "where",
+    "tile",
+    "dyn.tile",
+    "scatter",
+    "full",
+    "dyn.full",
+    // Comparison
+    "less",
+    "greater",
+    "less_equal",
+    "greater_equal",
+    // By definition copy and cast will become green or red based on inputs
+    "copy",
+    "cast",
+    "cast_like",
+    // Simple arithmetic
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "nn.bias_add",
+    "nn.batch_norm",
+    "sum",
+    "mean",
+    "sqrt",
+    "shape_of",
+    // Simple activations
+    "max",
+    "min",
+    "maximum",
+    "minimum",
+    "nn.relu",
+    "nn.leaky_relu",
+    "nn.prelu",
+    "nn.dropout",
+    // Complicated activations which saturate in a narrow range
+    "sigmoid",
+    "tanh",
+    // Pooling operations
+    "nn.max_pool1d",
+    "nn.max_pool2d",
+    "nn.max_pool3d",
+    "nn.avg_pool1d",
+    "nn.avg_pool2d",
+    "nn.avg_pool3d",
+    // "nn.global_max_pool1d", // does not exist yet
+    "nn.global_max_pool2d",
+    // "nn.global_max_pool3d", // does not exist yet
+    // "nn.global_avg_pool1d", // does not exist yet
+    "nn.global_avg_pool2d",
+    // "nn.global_avg_pool3d", // does not exist yet
+    "nn.adaptive_max_pool1d",
+    "nn.adaptive_max_pool2d",
+    "nn.adaptive_max_pool3d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+});
+OpStringSet DEFAULT_RED_LIST({
+    // In general if |f(x)| >> |x| for expected inputs then put the op here.
+    "exp",
+    "power",
+    "nn.cross_entropy",
+    "nn.cross_entropy_with_logits",
+    "nn.softmax",
+    "nn.l2_normalize",
+    // Error function doesn't seem to be able to be lowered into fp16 version in llvm.
+    // Move to gray list when it does.
+    "erf",
+});
+
+class DefaultFP16Colorer {
+  /* The default class to initially color ops for conversion using lists.
+
+  Creates a callable which given a CallNode* returns the node's color.
+  */
+ private:
+  std::unordered_map<std::string, FP16ConversionCategory> op_to_initial_color;
+
+ public:
+  DefaultFP16Colorer(OpStringSet red_list = DEFAULT_RED_LIST,
+                     OpStringSet gray_list = DEFAULT_GRAY_LIST,
+                     OpStringSet green_list = DEFAULT_GREEN_LIST) {
+    std::vector<std::pair<OpStringSet, FP16ConversionCategory>> lists_and_colors{
+        {red_list, RED}, {gray_list, GRAY}, {green_list, GREEN}};
+
+    for (auto list_and_color : lists_and_colors) {
+      OpStringSet ops = list_and_color.first;
+      FP16ConversionCategory color = list_and_color.second;
+      for (std::string op_name : ops) {
+        op_to_initial_color.insert({{op_name, color}});
+      }
+    }
+  }
+
+  FP16ConversionCategory operator()(const CallNode* call, bool ignore_missing = true) {
+    if (auto* op_node = (call->op).as<tvm::OpNode>()) {
+      std::string op_name = op_node->name;
+      auto color = op_to_initial_color.find(op_name);
+
+      if (color == op_to_initial_color.end()) {
+        if (ignore_missing) {
+          LOG(WARNING) << "Op name " << op_name << " not in included in fp16 conversion lists!.";
+          return RED;
+        } else {
+          LOG(FATAL) << "Op name " << op_name << " not in included in fp16 lists!.";

Review comment:
       Remove the period at the end.
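On the earlier question about ops missing from every list: going by the `DefaultFP16Colorer::operator()` quoted in this hunk, an unlisted op is colored RED with a warning when `ignore_missing` is true (the default) and hits `LOG(FATAL)` otherwise. A small Python sketch of that lookup, with the lists trimmed to a few entries; `Color` and `ListColorer` are hypothetical names, and raising `KeyError` stands in for `LOG(FATAL)`.

```python
import logging
from enum import Enum

class Color(Enum):
    RED = 0    # keep in FP32 for numerical reasons
    GRAY = 1   # follow the precision of the inputs
    GREEN = 2  # always convert to FP16

# Short excerpts of the default lists quoted in the diff.
DEFAULT_RED = {"exp", "nn.softmax", "erf"}
DEFAULT_GRAY = {"add", "nn.relu", "reshape"}
DEFAULT_GREEN = {"nn.conv2d", "nn.dense", "nn.batch_matmul"}

class ListColorer:
    """Map op names to an initial color, mirroring the C++ list lookup."""

    def __init__(self, red=DEFAULT_RED, gray=DEFAULT_GRAY, green=DEFAULT_GREEN):
        self.op_to_color = {}
        for ops, color in ((red, Color.RED), (gray, Color.GRAY), (green, Color.GREEN)):
            for name in ops:
                # Like std::unordered_map::insert, never overwrite an existing
                # entry: an op in several lists keeps the earliest list's color.
                self.op_to_color.setdefault(name, color)

    def __call__(self, op_name: str, ignore_missing: bool = True) -> Color:
        color = self.op_to_color.get(op_name)
        if color is not None:
            return color
        if ignore_missing:
            logging.warning("Op name %s not included in conversion lists", op_name)
            return Color.RED  # unlisted ops stay in FP32
        raise KeyError(f"Op name {op_name} not included in conversion lists")
```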

##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };
+
+using OpStringSet = std::unordered_set<std::string>;
+
+// Default lists inspired from TF's classifications:
+// github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+// They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
+OpStringSet DEFAULT_GREEN_LIST({
+    "nn.conv1d",
+    "nn.conv2d",
+    "nn.conv3d",
+    "nn.conv1d_transpose",
+    "nn.conv2d_transpose",
+    "nn.conv3d_transpose",
+    "nn.dense",
+    "nn.batch_matmul",
+});
+OpStringSet DEFAULT_GRAY_LIST({
+    // These ops add new data or change shape
+    "nn.pad",
+    "nn.batch_flatten",
+    "concatenate",
+    "zeros",
+    "split",
+    "squeeze",
+    "transpose",
+    "expand_dims",
+    "reshape",
+    "dyn.reshape",
+    "broadcast_to_like",
+    "dyn.broadcast_to",
+    "strided_slice",
+    "dyn.strided_slice",
+    "take",
+    "argwhere",
+    "where",
+    "tile",
+    "dyn.tile",
+    "scatter",
+    "full",
+    "dyn.full",
+    // Comparison
+    "less",
+    "greater",
+    "less_equal",
+    "greater_equal",
+    // By definition copy and cast will become green or red based on inputs
+    "copy",
+    "cast",
+    "cast_like",
+    // Simple arithmetic
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "nn.bias_add",
+    "nn.batch_norm",
+    "sum",
+    "mean",
+    "sqrt",
+    "shape_of",
+    // Simple activations
+    "max",
+    "min",
+    "maximum",
+    "minimum",
+    "nn.relu",
+    "nn.leaky_relu",
+    "nn.prelu",
+    "nn.dropout",
+    // Complicated activations which saturate in a narrow range
+    "sigmoid",
+    "tanh",
+    // Pooling operations
+    "nn.max_pool1d",
+    "nn.max_pool2d",
+    "nn.max_pool3d",
+    "nn.avg_pool1d",
+    "nn.avg_pool2d",
+    "nn.avg_pool3d",
+    // "nn.global_max_pool1d", // does not exist yet
+    "nn.global_max_pool2d",
+    // "nn.global_max_pool3d", // does not exist yet
+    // "nn.global_avg_pool1d", // does not exist yet
+    "nn.global_avg_pool2d",
+    // "nn.global_avg_pool3d", // does not exist yet
+    "nn.adaptive_max_pool1d",
+    "nn.adaptive_max_pool2d",
+    "nn.adaptive_max_pool3d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+});
+OpStringSet DEFAULT_RED_LIST({
+    // In general if |f(x)| >> |x| for expected inputs then put the op here.
+    "exp",
+    "power",
+    "nn.cross_entropy",
+    "nn.cross_entropy_with_logits",
+    "nn.softmax",
+    "nn.l2_normalize",
+    // Error function doesn't seem to be able to be lowered into fp16 version in llvm.
+    // Move to gray list when it does.
+    "erf",
+});
+
+class DefaultFP16Colorer {
+  /* The default class to initially color ops for conversion using lists.
+
+  Creates a callable which given a CallNode* returns the node's color.
+  */
+ private:
+  std::unordered_map<std::string, FP16ConversionCategory> op_to_initial_color;
+
+ public:
+  DefaultFP16Colorer(OpStringSet red_list = DEFAULT_RED_LIST,
+                     OpStringSet gray_list = DEFAULT_GRAY_LIST,
+                     OpStringSet green_list = DEFAULT_GREEN_LIST) {
+    std::vector<std::pair<OpStringSet, FP16ConversionCategory>> lists_and_colors{
+        {red_list, RED}, {gray_list, GRAY}, {green_list, GREEN}};
+
+    for (auto list_and_color : lists_and_colors) {
+      OpStringSet ops = list_and_color.first;
+      FP16ConversionCategory color = list_and_color.second;
+      for (std::string op_name : ops) {
+        op_to_initial_color.insert({{op_name, color}});
+      }
+    }
+  }
+
+  FP16ConversionCategory operator()(const CallNode* call, bool ignore_missing = true) {
+    if (auto* op_node = (call->op).as<tvm::OpNode>()) {
+      std::string op_name = op_node->name;
+      auto color = op_to_initial_color.find(op_name);
+
+      if (color == op_to_initial_color.end()) {
+        if (ignore_missing) {
+          LOG(WARNING) << "Op name " << op_name << " not in included in fp16 conversion lists!.";

Review comment:
       Remove the period at the end.







[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652190924



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));

Review comment:
      Downcast does the check, but the error message would be confusing. This is a minor point, though.
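One way to act on this is to validate the descriptor returned by the registered function before unpacking it, so that a malformed registration names the op and the offending field. The sketch below is a self-contained C++ analogue, not TVM code: `std::variant` stands in for `ObjectRef`, and `Field`, `OpDescriptor`, and `ParseOpDescriptor` are hypothetical names chosen for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>

// Stand-in for one element of TVM's Array<ObjectRef>: either an integer
// (the conversion category) or a string (a dtype name such as "float32").
using Field = std::variant<std::int64_t, std::string>;

struct OpDescriptor {
  std::int64_t category;
  std::string accumulation_dtype;
  std::string output_dtype;
};

// Hypothetical helper: check the [category, accumulation dtype, output dtype]
// triple before unpacking it, so a bad registration reports which op and which
// field are wrong instead of failing inside a generic downcast.
OpDescriptor ParseOpDescriptor(const std::string& op_name, const std::vector<Field>& fields) {
  if (fields.size() != 3) {
    throw std::runtime_error("Conversion function for op " + op_name +
                             " must return 3 fields, got " + std::to_string(fields.size()));
  }
  const std::int64_t* category = std::get_if<std::int64_t>(&fields[0]);
  if (category == nullptr) {
    throw std::runtime_error("Field 0 (conversion category) for op " + op_name +
                             " must be an integer");
  }
  const char* field_names[] = {"accumulation dtype", "output dtype"};
  std::string dtypes[2];
  for (int i = 0; i < 2; ++i) {
    const std::string* dtype = std::get_if<std::string>(&fields[i + 1]);
    if (dtype == nullptr) {
      throw std::runtime_error("Field " + std::to_string(i + 1) + " (" + field_names[i] +
                               ") for op " + op_name + " must be a dtype string");
    }
    dtypes[i] = *dtype;
  }
  return {*category, dtypes[0], dtypes[1]};
}
```

In the pass itself, equivalent checks would sit on `op_descriptor` just before the `Downcast` calls.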







[GitHub] [tvm] CoinCheung commented on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
CoinCheung commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-878683239


   Hi, 
   
   I am not sure whether this is a usage question or something in the code could be improved. I am using a fairly recent commit pulled from GitHub:
   
   ![image](https://user-images.githubusercontent.com/22693362/125371003-669a3c80-e3b2-11eb-9dc8-ea7526d22e70.png)
   
   I built from source following the steps on the documentation website. The only changes I made to `config.cmake` are:
   ```
   set(USE_LLVM ON)
   set(USE_CUDA ON)
   set(USE_CUDNN ON)
   ```
   Should I go to the discussion forum for help instead?





[GitHub] [tvm] AndrewZhaoLuo edited a comment on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo edited a comment on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-865193717


   Added some tracking issues for CUDA and Vulkan:
   https://github.com/apache/tvm/issues/8295 
   https://github.com/apache/tvm/issues/8294
   
   + main tracking issue: https://github.com/apache/tvm/issues/8296





[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652169800



##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,18 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def ToMixedPrecision(
+    mixed_precision_type="float16", ignore_missing_ops=True, warn_missing_ops=True
+):
+    """
+    Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version
+    where as many operations as possible are in the target mixed_precision_type.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered RewriteFP16 pass.

Review comment:
       ```suggestion
           The registered pass.
   ```

##########
File path: python/tvm/relay/op/op.py
##########
@@ -457,6 +458,29 @@ def register_fake_quantization_to_integer(op_name, func=None, level=10):
     return tvm.ir.register_op_attr(op_name, "FTVMFakeQuantizationToInteger", func, level)
 
 
+def register_mixed_precision_conversion(op_name, func=None, level=10):
+    """Register mixed precision conversion function for an op
+
+    Given an op the function should return information on how the value should be
+    converted. Specifically the function should take a call node and the target
+    mixed precision datatype (e.g. FP16) and return the conversion category
+    (see python/tvm/relay/transform/mixed_precision.py) as well as the accumulation
+    and output datatype of the oepration.

Review comment:
       ```suggestion
       and output datatype of the operation.
   ```
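The docstring above describes the contract for a registered conversion function: given a call and the target mixed precision dtype, it returns the conversion category plus the accumulation and output dtypes. A minimal self-contained C++ sketch of that contract follows. `ConversionInfo`, `ConversionRule`, and `LookupConversion` are hypothetical names, the per-op assignments are illustrative rather than the lists this PR registers, and the real hook also receives the `Call` node, which is omitted here.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>

enum MixedTypeConversionCategory : int {
  MIXED_PRECISION_ALWAYS = 0,
  MIXED_PRECISION_FOLLOW = 1,
  MIXED_PRECISION_NEVER = 2
};

// The triple a conversion function reports for one op.
struct ConversionInfo {
  MixedTypeConversionCategory category;
  std::string accumulation_dtype;
  std::string output_dtype;
};

// A rule receives the target mixed precision dtype (e.g. "float16").
using ConversionRule = std::function<ConversionInfo(const std::string& target_dtype)>;

// Illustrative rule: compute in the target dtype but accumulate in float32.
ConversionInfo AlwaysAccumulateFP32(const std::string& target_dtype) {
  return {MIXED_PRECISION_ALWAYS, "float32", target_dtype};
}

// Illustrative rule: keep a numerically sensitive op in float32.
ConversionInfo NeverConvert(const std::string&) {
  return {MIXED_PRECISION_NEVER, "float32", "float32"};
}

ConversionInfo LookupConversion(const std::unordered_map<std::string, ConversionRule>& rules,
                                const std::string& op_name, const std::string& target_dtype) {
  auto it = rules.find(op_name);
  if (it == rules.end()) {
    // Unregistered ops fall back to FOLLOW, mirroring the pass's default
    // when ignore_missing_ops is set.
    return {MIXED_PRECISION_FOLLOW, target_dtype, target_dtype};
  }
  return it->second(target_dtype);
}
```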

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;

Review comment:
       Add suffix `_` to all private class members.
   ```suggestion
     CachedCastNodes cast_nodes_cache_;
   ```
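For context, `cast_nodes_cache` maps a (node, dtype) pair to a previously created cast, and `CachedCast` also stores a reverse entry so that casting a cast back to its source dtype returns the original node. A minimal self-contained sketch of that idea, using the `_` suffix suggested here. `Node` and `CastCache` are hypothetical stand-ins for Relay expressions and the pass, and the sketch omits the pass's early return for non-float dtypes. The hash follows boost's `hash_combine` formula; the quoted `pair_hash` places `h1` and `h2` slightly differently, though both mix the two hashes.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for a Relay expression node.
struct Node {
  std::string op;     // e.g. "var" or "cast"
  std::string dtype;  // output dtype of this node
  const Node* input;  // operand of a cast, nullptr otherwise
};

// Hashes a std::pair using boost's hash_combine formula.
struct pair_hash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& pair) const {
    std::size_t h1 = std::hash<T1>()(pair.first);
    std::size_t h2 = std::hash<T2>()(pair.second);
    return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
  }
};

class CastCache {
 public:
  // Return a node of wanted_dtype, reusing an earlier cast when possible.
  const Node* CachedCast(const Node* expr, const std::string& wanted_dtype) {
    if (expr->dtype == wanted_dtype) return expr;
    auto search = cast_nodes_cache_.find({expr, wanted_dtype});
    if (search != cast_nodes_cache_.end()) return search->second;

    owned_nodes_.push_back(std::make_unique<Node>(Node{"cast", wanted_dtype, expr}));
    const Node* result = owned_nodes_.back().get();
    cast_nodes_cache_[{expr, wanted_dtype}] = result;
    // Reverse entry: casting the new node back to the source dtype yields the original.
    cast_nodes_cache_[{result, expr->dtype}] = expr;
    return result;
  }

  std::size_t num_casts_created() const { return owned_nodes_.size(); }

 private:
  std::unordered_map<std::pair<const Node*, std::string>, const Node*, pair_hash>
      cast_nodes_cache_;
  std::vector<std::unique_ptr<Node>> owned_nodes_;
};
```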

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));
+
+        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!";
+        if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!";
+
+        // If not registered, by default assume is a generic FOLLOW operation.
+        initial_category = MIXED_PRECISION_FOLLOW;
+        accumulation_dtype = DataType::Float(16);
+        output_dtype = DataType::Float(16);
+      }
+    } else {
+      LOG(FATAL) << "Unsupported op type in CallNode: " << pre_call_node->op;
+    }
+
+    // First check if all the new mutated args are in lower precision form
+    Array<Type> cur_arg_types;
+    bool all_args_mixed_type_compatible = true;
+    for (Expr arg : post_call_node->args) {
+      Type cur_arg_type = GetType(arg);
+      cur_arg_types.push_back(cur_arg_type);
+
+      if (initial_category == MIXED_PRECISION_FOLLOW && all_args_mixed_type_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_mixed_type_compatible = IsMixedPrecisionType(cur_arg_type, true) ||
+                                        arg->IsInstance<VarNode>() ||
+                                        arg->IsInstance<ConstantNode>();
+        all_args_mixed_type_compatible &= is_mixed_type_compatible;
+      }
+    }
+
+    // Determine the final category we want for conversion
+    MixedTypeConversionCategory final_category;
+    if (initial_category == MIXED_PRECISION_FOLLOW) {
+      final_category =
+          all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER;
+    } else {
+      final_category = initial_category;
+    }

Review comment:
       nit
   ```suggestion
       MixedTypeConversionCategory final_category = initial_category;
       if (initial_category == MIXED_PRECISION_FOLLOW) {
         final_category =
             all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER;
       }
   ```
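A minimal Python sketch of how the hunk above picks a category, assuming a plain dict stands in for the `FTVMMixedPrecisionConversionType` attribute map. `initial_category`, `final_category` and the registry entry are hypothetical names for illustration. `final_category` follows the shape of the suggested rewrite: start from the initial category and only resolve FOLLOW.

```python
import warnings
from enum import IntEnum


class MixedTypeConversionCategory(IntEnum):
    # Same values as the C++ enum in to_mixed_precision.cc.
    MIXED_PRECISION_ALWAYS = 0
    MIXED_PRECISION_FOLLOW = 1
    MIXED_PRECISION_NEVER = 2


def initial_category(op_name, registry, ignore_missing_ops=True, warn_missing_ops=True):
    """Return the registered category, defaulting unregistered ops to FOLLOW."""
    if op_name in registry:
        return registry[op_name]
    if not ignore_missing_ops:
        raise RuntimeError(f"Op {op_name} not in conversion lists!")
    if warn_missing_ops:
        warnings.warn(f"Op {op_name} not in conversion lists!")
    return MixedTypeConversionCategory.MIXED_PRECISION_FOLLOW


def final_category(initial, all_args_mixed_type_compatible):
    """Resolve FOLLOW to ALWAYS or NEVER; other categories pass through unchanged."""
    final = initial
    if initial == MixedTypeConversionCategory.MIXED_PRECISION_FOLLOW:
        final = (
            MixedTypeConversionCategory.MIXED_PRECISION_ALWAYS
            if all_args_mixed_type_compatible
            else MixedTypeConversionCategory.MIXED_PRECISION_NEVER
        )
    return final


# Hypothetical registry standing in for the op attribute map.
registry = {"nn.conv2d": MixedTypeConversionCategory.MIXED_PRECISION_ALWAYS}
cat = initial_category("nn.relu", registry, warn_missing_ops=False)
print(final_category(cat, all_args_mixed_type_compatible=True).name)  # MIXED_PRECISION_ALWAYS
```

Initializing `final` to `initial` removes the redundant `else` branch, which is the point of the nit.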

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }

Review comment:
       ```suggestion
       CHECK(expr_node) << "Non-expression node found in cast: " << expr;
   ```
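To show what `CachedCast` caches, here is a hedged Python sketch. Dtypes are plain strings, `id()` plays the role of the raw `const ExprNode*` key, and the tuple `("cast", expr, dtype)` is a hypothetical stand-in for a relay `Cast` node. `CastCache` is an illustrative name, not part of the PR.

```python
class CastCache:
    """Bidirectional cast cache keyed by (id(expr), dtype).

    Both key objects stay alive because the cached values reference them,
    so their id()s cannot be reused while the cache exists.
    """

    def __init__(self):
        self._cache = {}

    def cast(self, expr, expr_dtype, wanted_dtype):
        # Non-floating-point values (e.g. integer indices) are left alone.
        if not expr_dtype.startswith("float"):
            return expr
        if expr_dtype == wanted_dtype:
            return expr
        key = (id(expr), wanted_dtype)
        if key in self._cache:
            return self._cache[key]
        # Hypothetical stand-in for relay's Cast(expr, wanted_dtype).
        result = ("cast", expr, wanted_dtype)
        self._cache[key] = result
        # Reverse entry: casting the result back yields the original node.
        self._cache[(id(result), expr_dtype)] = expr
        return result


cache = CastCache()
x = object()  # stands in for an fp32 expression
y = cache.cast(x, "float32", "float16")
print(cache.cast(y, "float16", "float32") is x)  # True
```

The reverse entry means a float32 -> float16 -> float32 round trip resolves to the original node instead of stacking two casts.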

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }

Review comment:
       ```suggestion
       CHECK(post_call_node) << "Expected a CallNode, but got " << post;
   ```
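The tuple handling in `CastArg` above recurses field by field and returns the original tuple when no field changed. A minimal Python sketch of that recursion follows, assuming a tensor type is represented as a dtype string and a tuple type as a Python tuple of types. `cast_arg` and `simple_cast` are hypothetical helpers, and `simple_cast` is a simplified, uncached stand-in for `CachedCast`.

```python
def simple_cast(expr, src_dtype, dst_dtype):
    # Uncached stand-in for CachedCast: skip non-floats and no-op casts.
    if not src_dtype.startswith("float") or src_dtype == dst_dtype:
        return expr
    return ("cast", expr, dst_dtype)


def cast_arg(expr, expr_type, wanted_dtype, cast_fn=simple_cast):
    """Recursively cast a tensor or (possibly nested) tuple argument."""
    if isinstance(expr_type, str):
        # Tensor type: delegate to the (cached) cast.
        return cast_fn(expr, expr_type, wanted_dtype)
    if isinstance(expr_type, tuple):
        new_fields = []
        all_same = True
        for i, field_type in enumerate(expr_type):
            element = expr[i]  # stand-in for GetField(expr, i)
            casted = cast_arg(element, field_type, wanted_dtype, cast_fn)
            new_fields.append(casted)
            all_same = all_same and casted is element
        # Reuse the original tuple when nothing changed.
        return expr if all_same else tuple(new_fields)
    raise TypeError(f"Unsupported type {expr_type!r} for argument casting")
```

Returning the untouched tuple keeps node identity stable, which matters for the identity-keyed cast cache.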

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16

Review comment:
       Please follow the other passes and use the standard docstring format.
   ```suggestion
     /*! \brief The target datatype we want to convert to e.g. FP16 */
   ```
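Related to documenting these helpers: the `pair_hash` functor in the diff says it uses Boost's combine strategy. Boost's classic `hash_combine` is `seed ^= hash(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2)`. The expression in the diff, `h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2))`, shifts the second hash instead of the seed, so it is a variant of that formula rather than the same thing. Below is a Python sketch of the Boost form on 64-bit unsigned values. `hash_combine` and `pair_hash` here are illustrative names, and the 64-bit mask is an assumption about `std::size_t`.

```python
MASK64 = (1 << 64) - 1  # assumes a 64-bit std::size_t


def hash_combine(seed, value_hash):
    """Boost-style hash_combine on 64-bit unsigned integers (sketch)."""
    seed &= MASK64
    mixed = (value_hash + 0x9E3779B9 + ((seed << 6) & MASK64) + (seed >> 2)) & MASK64
    return seed ^ mixed


def pair_hash(first, second):
    # Start from the hash of the first element and fold in the second.
    return hash_combine(hash(first) & MASK64, hash(second) & MASK64)


print(hex(hash_combine(0, 0)))  # 0x9e3779b9
```

Either form is a valid hash for `std::unordered_map`, since it only needs to be deterministic; the difference affects how evenly keys spread across buckets.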

##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,18 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def ToMixedPrecision(
+    mixed_precision_type="float16", ignore_missing_ops=True, warn_missing_ops=True
+):
+    """
+    Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version
+    where as many operations as possible are in the target mixed_precision_type.

Review comment:
       Please add a docstring for the parameters.
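One possible shape for the parameter docstring, a sketch in the numpydoc style the neighbouring wrappers in `transform.py` use. The descriptions are paraphrased from the member comments and constructor check in `to_mixed_precision.cc`. The body is deliberately stubbed out rather than guessing at the FFI call.

```python
def ToMixedPrecision(mixed_precision_type="float16", ignore_missing_ops=True, warn_missing_ops=True):
    """
    Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version
    where as many operations as possible are in the target mixed_precision_type.

    Parameters
    ----------
    mixed_precision_type : str
        The target datatype to convert to, e.g. "float16". Only IEEE floating
        point types and bfloat16 are accepted.

    ignore_missing_ops : bool
        If False, raise a fatal error when an op without a registered
        FTVMMixedPrecisionConversionType attribute is encountered. If True,
        such ops are treated as MIXED_PRECISION_FOLLOW.

    warn_missing_ops : bool
        If True, emit a warning when an op without a registered
        FTVMMixedPrecisionConversionType attribute is encountered.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered pass.
    """
    # Body elided in this sketch; the real wrapper forwards to the C++ pass.
    raise NotImplementedError
```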

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));
+
+        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!";
+        if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!";

Review comment:
       Looks like we could merge these two configs (`ignore_missing_ops` and `warn_missing_ops`) into a single `verbose` option, or something like it.
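
       A minimal standalone sketch of what a merged option might look like. `MissingOpPolicy` and `HandleMissingOp` are hypothetical names and are not part of this PR or of TVM's API. The idea is that the fatal branch in the diff runs before the warning and ends the pass, so the four combinations of the two booleans collapse into three behaviors: silent, warn, or error. A `std::runtime_error` stands in for the fatal log here.

   ```cpp
   #include <iostream>
   #include <stdexcept>
   #include <string>

   // Hypothetical single option replacing ignore_missing_ops + warn_missing_ops.
   enum class MissingOpPolicy { kIgnore, kWarn, kError };

   // Returns true if a warning was emitted, false if the op was silently ignored.
   // Throws (standing in for a fatal log) when the policy is kError.
   bool HandleMissingOp(const std::string& op_name, MissingOpPolicy policy) {
     switch (policy) {
       case MissingOpPolicy::kIgnore:
         return false;
       case MissingOpPolicy::kWarn:
         std::cerr << "Op " << op_name << " not in conversion lists!\n";
         return true;
       case MissingOpPolicy::kError:
         break;
     }
     throw std::runtime_error("Op " + op_name + " not in conversion lists!");
   }
   ```

       One enum also rules out the contradictory "error but also warn" combination that two independent booleans allow.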

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }

Review comment:
       If this path is fatal, then we shouldn't return afterwards.
   ```suggestion
       }
       CHECK(0) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
   ```
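
       A minimal standalone sketch of the control-flow shape behind this suggestion. It is not TVM code: `Fatal`, `Kind`, and `DescribeKind` are hypothetical names, and `Fatal` is a throwing helper, not TVM's `LOG(FATAL)`. Supported cases return early, and the function ends with an unconditional fatal call. Because the helper is declared `[[noreturn]]`, the compiler knows control cannot fall off the end, so no placeholder `return` is needed.

   ```cpp
   #include <stdexcept>
   #include <string>

   // Hypothetical fatal helper; [[noreturn]] tells the compiler it never returns.
   [[noreturn]] void Fatal(const std::string& msg) { throw std::runtime_error(msg); }

   // Hypothetical kinds standing in for the TensorType / TupleType dispatch.
   enum class Kind { kTensor, kTuple, kOther };

   std::string DescribeKind(Kind kind) {
     if (kind == Kind::kTensor) {
       return "tensor";
     } else if (kind == Kind::kTuple) {
       return "tuple";
     }
     // Unsupported case: no dummy `return` after a [[noreturn]] call.
     Fatal("Unsupported kind, we don't know how to handle it");
   }
   ```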

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;

Review comment:
       ```suggestion
       if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16()) {
         LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got "
                    << mixed_precision_type;
       }
   ```

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));
+
+        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!";
+        if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!";
+
+        // If not registered, by default assume is a generic FOLLOW operation.
+        initial_category = MIXED_PRECISION_FOLLOW;
+        accumulation_dtype = DataType::Float(16);
+        output_dtype = DataType::Float(16);
+      }
+    } else {
+      LOG(FATAL) << "Unsupported op type in CallNode: " << pre_call_node->op;
+    }
+
+    // First check if all the new mutated args are in lower precision form
+    Array<Type> cur_arg_types;
+    bool all_args_mixed_type_compatible = true;
+    for (Expr arg : post_call_node->args) {
+      Type cur_arg_type = GetType(arg);
+      cur_arg_types.push_back(cur_arg_type);
+
+      if (initial_category == MIXED_PRECISION_FOLLOW && all_args_mixed_type_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_mixed_type_compatible = IsMixedPrecisionType(cur_arg_type, true) ||
+                                        arg->IsInstance<VarNode>() ||
+                                        arg->IsInstance<ConstantNode>();
+        all_args_mixed_type_compatible &= is_mixed_type_compatible;
+      }
+    }
+
+    // Determine the final category we want for conversion
+    MixedTypeConversionCategory final_category;
+    if (initial_category == MIXED_PRECISION_FOLLOW) {
+      final_category =
+          all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER;
+    } else {
+      final_category = initial_category;
+    }
+
+    // Create the new arguments to the call.
+    DataType wanted_arg_dtypes =
+        final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type : DataType::Float(32);
+    auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes);
+    Array<Expr> new_args = call_args_and_types.first;
+    Array<Type> new_arg_types;
+
+    if (pre_call_node->op.as<FunctionNode>()) {
+      // Function Nodes don't store type info in the Call, it should be a []
+      new_arg_types = pre_call_node->type_args;
+    } else {
+      new_arg_types = call_args_and_types.second;
+    }
+
+    // Finally create the new attributes.
+    if (final_category == MIXED_PRECISION_ALWAYS) {
+      Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype);
+      Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span);
+      if (accumulation_dtype != output_dtype) {
+        output = CastArg(output, GetType(output), output_dtype);
+      }
+      return output;
+    }
+
+    return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
+  }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    // Erase the ret_type annotation and let the normal pass recalculate
+    const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
+    return ExprMutator::VisitExpr_(func);
+  }
+
+  Expr VisitExpr_(const LetNode* op) final {
+    // First convert as much of the bound computation to lower precision as possible
+    Expr value = this->Mutate(op->value);
+
+    // Then rewrite the var type and associated expression
+    Var var = Downcast<Var>(this->Mutate(op->var));
+    VarNode* mutable_var = const_cast<VarNode*>((op->var).as<VarNode>());
+    mutable_var->type_annotation = GetType(value);
+    mutable_var->checked_type_ = mutable_var->type_annotation;
+
+    // Mutate body last as it may depend on previous results
+    Expr body = this->Mutate(op->body);
+    return Let(var, value, body, op->span);
+  }
+};
+
+Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type,
+                      bool ignore_missing_ops, bool warn_missing_ops) {
+  MixedPrecisionPass converter =
+      MixedPrecisionPass(mixed_precision_type, ignore_missing_ops, warn_missing_ops);
+  auto result = converter.Mutate(expr);
+  return result;
+}
+
+namespace transform {
+
+Pass ToMixedPrecision(DataType mixed_precision_type, bool ignore_missing_ops,
+                      bool warn_missing_ops) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(
+            ToMixedPrecision(f, mixed_precision_type, ignore_missing_ops, warn_missing_ops));
+      };
+  return CreateFunctionPass(pass_func, 10, "ToMixedPrecision", {});

Review comment:
       Is 10 the right opt level? IMHO it should be 0, since this pass is meant to be applied manually by the user rather than managed by the pass manager.
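For context on the opt-level question: as we understand TVM's Sequential pipeline, a pass runs only if it is not in the context's disabled list and either appears in the required list or has an opt_level no higher than the PassContext's opt_level. The standalone sketch below models that rule; `PassInfo`, `PassContext` and `PassEnabled` here are simplified stand-ins, not TVM's actual classes, and the default context opt_level of 2 is an assumption.

```cpp
#include <cassert>
#include <set>
#include <string>

// Simplified stand-in for a pass's metadata (not TVM's PassInfo class).
struct PassInfo {
  std::string name;
  int opt_level;
};

// Simplified stand-in for the pass context (not TVM's PassContext class).
struct PassContext {
  int opt_level = 2;                    // assumed default context level
  std::set<std::string> required_pass;  // passes forced to run
  std::set<std::string> disabled_pass;  // passes forced to be skipped
};

// Disabled wins over required; otherwise a pass runs only when the
// context's opt_level is at least the pass's own opt_level.
bool PassEnabled(const PassContext& ctx, const PassInfo& info) {
  if (ctx.disabled_pass.count(info.name) > 0) return false;
  if (ctx.required_pass.count(info.name) > 0) return true;
  return ctx.opt_level >= info.opt_level;
}
```

Under this model, a pass created at opt_level 10 is skipped in a default pipeline unless the user raises the context's opt_level to 10 or lists the pass in `required_pass`, whereas a pass at opt_level 0 always runs unless it is explicitly disabled.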

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));

Review comment:
       Need to validate `op_descriptor`: it should have exactly 3 elements, of types (int, str, str).
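One way to address this is to check the descriptor's length and element types before the `Downcast` calls, so that a misregistered op fails with a readable message instead of inside a cast. In the pass itself this would presumably be a few `ICHECK`s on the returned `Array<ObjectRef>`; the standalone sketch below uses `std::variant<int64_t, std::string>` as a stand-in for `ObjectRef`, and `ValidateOpDescriptor` is a hypothetical helper name.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>

// Stand-in for the Array<ObjectRef> returned by an op's
// FTVMMixedPrecisionConversionType: [category, accumulation dtype, output dtype].
using DescriptorField = std::variant<int64_t, std::string>;
using OpDescriptor = std::vector<DescriptorField>;

// Same numbering as MixedTypeConversionCategory in the diff.
enum MixedTypeConversionCategory : int {
  MIXED_PRECISION_ALWAYS = 0,
  MIXED_PRECISION_FOLLOW = 1,
  MIXED_PRECISION_NEVER = 2
};

struct ValidatedDescriptor {
  MixedTypeConversionCategory category;
  std::string accumulation_dtype;
  std::string output_dtype;
};

// Hypothetical helper: validate shape and element types, then unpack.
ValidatedDescriptor ValidateOpDescriptor(const OpDescriptor& desc, const std::string& op_name) {
  if (desc.size() != 3) {
    throw std::invalid_argument(op_name + ": expected 3 descriptor fields, got " +
                                std::to_string(desc.size()));
  }
  // Field 0 must be an integer naming a known conversion category.
  const int64_t* category = std::get_if<int64_t>(&desc[0]);
  if (category == nullptr || *category < MIXED_PRECISION_ALWAYS ||
      *category > MIXED_PRECISION_NEVER) {
    throw std::invalid_argument(op_name + ": field 0 must be a valid conversion category");
  }
  // Fields 1 and 2 must be dtype strings.
  const std::string* accum = std::get_if<std::string>(&desc[1]);
  const std::string* out = std::get_if<std::string>(&desc[2]);
  if (accum == nullptr || out == nullptr) {
    throw std::invalid_argument(op_name + ": fields 1 and 2 must be dtype strings");
  }
  return {static_cast<MixedTypeConversionCategory>(*category), *accum, *out};
}
```

The category range check relies on the values 0 to 2 declared in the diff's enum; if new categories are added, the upper bound would have to move with them.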

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));
+
+        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!";
+        if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!";
+
+        // If not registered, by default assume is a generic FOLLOW operation.
+        initial_category = MIXED_PRECISION_FOLLOW;
+        accumulation_dtype = DataType::Float(16);
+        output_dtype = DataType::Float(16);

Review comment:
       Shouldn't these be `mixed_precision_type`?
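To illustrate the point, here is a minimal C++ sketch of an op lookup whose fallback for unregistered ops reuses the configured mixed precision type. With this fallback, a bfloat16 run does not silently get float16 defaults. The names `Category`, `OpDescriptor`, `DescriptorFunc`, `GetDescriptor`, and the string-keyed registry are hypothetical stand-ins for the pass's `MixedTypeConversionCategory` and `FTVMMixedPrecisionConversionType` attribute map. They are not TVM APIs.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for the pass's conversion categories.
enum class Category { kAlways, kFollow, kNever };

// Hypothetical descriptor: category plus accumulation and output dtypes.
struct OpDescriptor {
  Category category;
  std::string accumulation_dtype;
  std::string output_dtype;
};

// A registered op maps the target mixed precision type to its descriptor.
using DescriptorFunc = std::function<OpDescriptor(const std::string& mixed_type)>;

// Looks up an op's descriptor. Unregistered ops fall back to FOLLOW and use
// the configured mixed precision type instead of a hard-coded "float16".
OpDescriptor GetDescriptor(const std::unordered_map<std::string, DescriptorFunc>& registry,
                           const std::string& op_name, const std::string& mixed_type) {
  auto it = registry.find(op_name);
  if (it != registry.end()) return it->second(mixed_type);
  return {Category::kFollow, mixed_type, mixed_type};
}
```

For example, with a registry that only knows `nn.conv2d`, looking up an unregistered op with a `bfloat16` target yields a FOLLOW descriptor whose accumulation and output dtypes are both `bfloat16`.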

##########
File path: tests/python/relay/test_to_mixed_precision.py
##########
@@ -0,0 +1,446 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for testing ToMixedPrecision pass"""

Review comment:
       Could you benchmark how long this test set takes to run? Since it includes end-to-end model execution (e.g., LSTM), I'm a bit worried it may take too long and slow down the CI.







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r647779530



##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    return h1 ^ (h2 << 1);

Review comment:
       Done.







[GitHub] [tvm] anijain2305 commented on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-855073961


   Thanks for the useful feature. Is this ready for review?





[GitHub] [tvm] AndrewZhaoLuo commented on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-858045479


   Hey folks, I've covered the simple requested changes. Here is the list of the more involved changes, along with the associated reviewer. Several of these were planned as future PRs, but it might be best to just get this right the first time (since it doesn't really touch other files):
   - [ ] Support other floating point types out of the box (e.g. bfloat16)
   - [ ] Naming of things (pass, GREEN/RED/GRAY, etc.)
   - [ ] Python interface for Coloring/Accumulation logic
   - [ ] How to register ops for coloring
   
   Let me know if I missed anything.





[GitHub] [tvm] mbrookhart commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r647748964



##########
File path: tests/python/frontend/mxnet/test_forward.py
##########
@@ -1223,6 +1221,8 @@ def verify(shape, axis=1, fix_gamma=False):
 
 @tvm.testing.uses_gpu
 def test_forward_instance_norm():
+    np.random.seed(90)
+

Review comment:
       Tianqi prefers that we don't set random seeds, so that intermittent bugs can surface across CI runs.

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    return h1 ^ (h2 << 1);

Review comment:
       If I remember correctly, XOR hash combining is pretty prone to hash collisions. Maybe use the Boost approach? `return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2))`
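For reference, here is a minimal, self-contained sketch of the difference. It assumes the canonical Boost `hash_combine` mixing step, `seed ^ (h + 0x9e3779b9 + (seed << 6) + (seed >> 2))`. That step applies the shifts to the running seed, so it differs slightly from the expression quoted above. `XorCombine`, `BoostCombine`, and `PairHash` are illustrative names only.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Combine used in the original diff: h1 ^ (h2 << 1).
// Collides easily, e.g. raw hashes (2, 0) and (0, 1) both give 2.
std::size_t XorCombine(std::size_t h1, std::size_t h2) { return h1 ^ (h2 << 1); }

// Boost-style combine: mixes the new hash with a golden-ratio constant
// and with shifted copies of the running seed.
std::size_t BoostCombine(std::size_t seed, std::size_t h) {
  return seed ^ (h + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

// Hasher for std::pair keys, usable with std::unordered_map.
struct PairHash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& p) const {
    return BoostCombine(std::hash<T1>()(p.first), std::hash<T2>()(p.second));
  }
};
```

With raw hash values `(2, 0)` and `(0, 1)`, `XorCombine` returns `2` for both, while `BoostCombine` separates them. `PairHash` can then key a cache on, say, a hypothetical `(node id, dtype name)` pair.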

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    return h1 ^ (h2 << 1);
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    if ((mutable_attrs->out_dtype).is_float()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    if ((mutable_attrs->dtype).is_float()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}

Review comment:
       Since you're using the recursive mutator here, you might run into stack overflows on larger models. I haven't looked at this pass in much detail yet; is it possible to do this with a post-order traversal (or a MixedModeMutator)?
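A minimal sketch of the idea behind this suggestion: visiting a dataflow graph in post-order with an explicit stack keeps traversal depth off the call stack, so a long chain of ops cannot overflow it. `Node` and `PostOrder` are hypothetical stand-ins, not TVM's `ExprMutator` or `MixedModeMutator` API. The sketch only shows the visiting order such a mutator needs: inputs are handled before their users, and shared nodes are visited once.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical stand-in for an expression node: a dataflow node with inputs.
struct Node {
  int id;
  std::vector<const Node*> inputs;
};

// Returns the nodes reachable from root in post-order (every input before its
// user), visiting each shared node once. An explicit stack replaces recursion.
// Assumes the graph is acyclic, as dataflow graphs are.
std::vector<const Node*> PostOrder(const Node* root) {
  std::vector<const Node*> order;
  std::unordered_set<const Node*> visited;
  // Each frame holds a node and the index of its next input to expand.
  std::vector<std::pair<const Node*, std::size_t>> stack;
  stack.push_back({root, 0});
  visited.insert(root);
  while (!stack.empty()) {
    auto& frame = stack.back();
    const Node* node = frame.first;
    if (frame.second < node->inputs.size()) {
      // Advance this frame before pushing, since push_back may invalidate `frame`.
      const Node* child = node->inputs[frame.second++];
      if (visited.insert(child).second) stack.push_back({child, 0});
    } else {
      // All inputs are done, so the node itself can be emitted (i.e. rewritten).
      order.push_back(node);
      stack.pop_back();
    }
  }
  return order;
}
```

Consider a diamond where `a` uses `b` and `c`, which both use `d`. On that graph, this yields `d, b, c, a`, and a 200,000-node chain is processed without any recursion.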







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648609224



##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;

Review comment:
       An example with concat:
   
   We have two branches whose outputs are fed into a concat.
   
   The first branch has a RED operation and returns an FP32 tensor.
   The second branch returns an FP16 tensor.
   
   Thinking about it more, it might be better to be smarter about GRAY ops when heterogeneous floating point types come in.
   
   E.g. say we had a concat with 10 fp16 args and 1 fp32 arg. It would be wasteful to convert everything to fp32 by default and color the op RED in this case.
   
   I will change this so that the number of fp16 and fp32 args is taken into account: if fp16 args are in the majority, or there is a tie, we color the op GREEN; otherwise we color it RED. Thoughts?
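   A minimal sketch of the proposed majority rule, assuming each argument of a GRAY op can be summarized by its dtype string. `Color` and `resolve_gray_color` are hypothetical names for illustration only; they are not part of the pass, and the sketch does not model the current special case where Vars and Constants always count as castable.

```python
from enum import Enum
from typing import Iterable


class Color(Enum):
    # Hypothetical stand-ins for the pass's GREEN/RED categories.
    GREEN = "green"  # run the op in fp16
    RED = "red"  # run the op in fp32


def resolve_gray_color(arg_dtypes: Iterable[str]) -> Color:
    """Pick a final color for a GRAY op by majority vote over its arguments.

    Only float16 and float32 arguments vote; anything else (e.g. int32
    indices) is ignored. A tie, including the case where no argument
    votes at all, resolves to GREEN.
    """
    num_fp16 = 0
    num_fp32 = 0
    for dtype in arg_dtypes:
        if dtype == "float16":
            num_fp16 += 1
        elif dtype == "float32":
            num_fp32 += 1
    return Color.GREEN if num_fp16 >= num_fp32 else Color.RED


# The concat case above: 10 fp16 inputs and 1 fp32 input.
print(resolve_gray_color(["float16"] * 10 + ["float32"]).name)  # GREEN
```

   After the color is chosen, every floating point argument would still be cast to the winning dtype (as `CastAllArgs` does today), so the single fp32 input in the concat example costs one cast instead of ten.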

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;
+    } else {
+      final_color = initial_color;
+    }
+
+    // Create the new arguments to the call.
+    DataType wanted_arg_dtypes = final_color == GREEN ? DataType::Float(16) : DataType::Float(32);
+    auto call_args_and_types = CastAllArgs(new_args, new_arg_types, wanted_arg_dtypes);
+
+    Array<Expr> call_args = call_args_and_types.first;
+    Array<Type> call_arg_types;
+
+    if (call_node->op.as<FunctionNode>()) {
+      // Function nodes don't store type info in the Call; it should be []
+      call_arg_types = call_node->type_args;
+    } else {
+      call_arg_types = call_args_and_types.second;
+    }
+
+    // Finally create the new attributes.
+    if (final_color == GREEN) {
+      FP16OpDType output_dtypes = output_dtype_func(call_node);
+
+      Attrs new_attrs = GetNewAttrs(call_node, output_dtypes.accumulation_dtype);
+      Expr output = Call(new_op, call_args, new_attrs, call_arg_types, call_node->span);
+      if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) {
+        output = CastArg(output, GetType(output), output_dtypes.output_dtype);
+      }
+      return output;
+    } else {
+      return Call(new_op, call_args, call_node->attrs, call_arg_types, call_node->span);
+    }

Review comment:
       Done







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648617582



##########
File path: tests/python/relay/test_fp32_to_fp16_transform.py
##########
@@ -0,0 +1,328 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for testing FP32 -> FP16 pass"""
+from typing import Any, Dict, List
+
+import numpy as np
+import tvm
+from tvm import relay
+from tvm.relay.testing import lstm
+from tvm.relay.transform import AMPRewrite
+from tvm.relay.transform.transform import InferType
+
+
+def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List:
+    dev = tvm.device("llvm", 0)
+    intrp = relay.create_executor("debug", mod, device=dev, target="llvm")
+    result = intrp.evaluate()(**mod_params)
+    if isinstance(result, tvm.runtime.container.ADT):
+        result = [r.asnumpy() for r in result]
+        return result
+    else:
+        return [result.asnumpy()]
+
+
+def verify_fp32_fp16_output_close(
+    mod: tvm.runtime.Module, mod_params: Dict[str, Any], rtol: float = 1e-3, atol: float = 0
+) -> tvm.runtime.Module:
+    mod = InferType()(mod)
+    result_fp32 = run_module(mod, mod_params)
+    fp16_mod = AMPRewrite()(mod)
+    result_fp16 = run_module(fp16_mod, mod_params)
+
+    # Ensure the results are close
+    for fp32, fp16 in zip(result_fp32, result_fp16):
+        np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
+
+    return fp16_mod
+
+
+def test_lstm():
+    """A small stress test on a single unrolled lstm unit.
+
+    Has internal functions and let statements the pass must work on.
+    """
+    np.random.seed(5628)
+    units = 3
+    iterations = 5
+    mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units)
+
+    # This is an unrolled lstm so each data should be the previous results but
+    # we don't care, we just want to stress test things.
+    for i in range(iterations):
+        mod_params["data" if i == 0 else f"data{i}"] = np.random.uniform(
+            -10, 10, (1, units)
+        ).astype("float32")
+
+    verify_fp32_fp16_output_close(mod, mod_params, rtol=0.01, atol=0.01)
+
+
+def test_convert_single_conv():
+    """Conv is a green listed operation, meaning it will always use an fp16 workload.
+
+    By default it accumulates to fp32 and outputs fp16.
+    """
+    np.random.seed(208)
+
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+    mod = tvm.IRModule.from_expr(conv)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+    }
+    fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+
+    expected_mod = tvm.IRModule.from_expr(
+        relay.cast(
+            relay.nn.conv2d(
+                relay.cast(data, "float16"),
+                relay.cast(weight, "float16"),
+                strides=(1, 1),
+                padding=(1, 1),
+                out_dtype="float32",
+            ),
+            "float16",
+        )
+    )
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+
+    assert not tvm.ir.structural_equal(fp16_mod, mod)
+    assert tvm.ir.structural_equal(fp16_mod, expected_mod)
+
+
+def test_convert_conv_bn():
+    """Conv is green and batch norm is gray. As Conv should output fp16, batch_norm should be green."""
+    np.random.seed(208)
+
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+
+    bn_shape = [5]
+    gamma = relay.var("gamma", shape=bn_shape)
+    beta = relay.var("beta", shape=bn_shape)
+    moving_mean = relay.var("moving_mean", shape=bn_shape)
+    moving_var = relay.var("moving_var", shape=bn_shape)
+    bn = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var)
+    mod = tvm.IRModule.from_expr(bn[0])
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+        "gamma": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+        "beta": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+        "moving_mean": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+        "moving_var": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+    }
+    fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+
+    # Creating expected module
+    data = relay.cast(relay.var("data", shape=data_shape), "float16")
+    weight = relay.cast(relay.var("weight", shape=weight_shape), "float16")
+    conv = relay.cast(
+        relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32"),
+        "float16",
+    )
+
+    bn_shape = [5]
+    gamma = relay.cast(relay.var("gamma", shape=bn_shape), "float16")
+    beta = relay.cast(relay.var("beta", shape=bn_shape), "float16")
+    moving_mean = relay.cast(relay.var("moving_mean", shape=bn_shape), "float16")
+    moving_var = relay.cast(relay.var("moving_var", shape=bn_shape), "float16")
+    bn = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var)
+
+    expected_mod = tvm.IRModule.from_expr(bn[0])
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+    assert not tvm.ir.structural_equal(fp16_mod, mod)
+    assert tvm.ir.structural_equal(fp16_mod, expected_mod)
+
+
+def test_do_not_convert_softmax():
+    """Softmax is a red listed operation and therefore should never be fp16."""
+    np.random.seed(209)
+    shape = [1, 2, 3]
+    a = relay.var("a", shape=shape)
+    b = relay.nn.softmax(a)
+    mod = tvm.IRModule.from_expr(b)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "a": np.random.uniform(-1, 1, size=shape).astype("float32"),
+    }
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.0, rtol=0)
+    assert tvm.ir.structural_equal(mod, output_mod)
+
+
+def test_green_gray_propagates_simple():
+    """Conv is a green listed operation, while addition is gray.
+
+    As Conv outputs fp16 the add should be done in fp16.
+    """
+    np.random.seed(210)
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+    conv = conv + conv
+    mod = tvm.IRModule.from_expr(conv)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+    }
+    fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+
+    conv_expr = relay.cast(
+        relay.nn.conv2d(
+            relay.cast(data, "float16"),
+            relay.cast(weight, "float16"),
+            strides=(1, 1),
+            padding=(1, 1),
+            out_dtype="float32",
+        ),
+        "float16",
+    )
+    expected_mod = tvm.IRModule.from_expr(conv_expr + conv_expr)
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+
+    assert not tvm.ir.structural_equal(fp16_mod, mod)
+    assert tvm.ir.structural_equal(fp16_mod, expected_mod)
+
+
+def test_red_gray_propagates_simple():
+    """Everything after a softmax should be in FP32 (except green colored ops)"""
+    np.random.seed(211)
+    shape = [1, 2, 3]
+    a = relay.var("a", shape=shape)
+    b = relay.nn.softmax(a)
+    c = b + b
+    mod = tvm.IRModule.from_expr(c)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "a": np.random.uniform(-1, 1, size=shape).astype("float32"),
+    }
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.0, rtol=0.0)
+
+    assert tvm.ir.structural_equal(mod, output_mod)
+
+
+def test_let_statement_simple():
+    """A 'simple' let statement example.
+
+    Noticeable is the mutation of the bound variable types.
+    """
+    np.random.seed(211)
+    var1 = relay.var("var1", shape=[1, 20])
+    var2 = relay.var("var2", shape=[1, 20])
+
+    data = relay.var("data", shape=[1, 20])
+    weight = relay.var("weight", shape=[20, 20])
+
+    r1 = var1 + var1
+
+    r2 = var2 + var2
+    let2 = relay.Let(var2, relay.nn.dense(r1, weight, units=20), r2)
+    let1 = relay.Let(var1, relay.nn.dense(data, weight, units=20), let2)
+
+    mod = tvm.IRModule.from_expr(let1)
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"),
+    }
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+
+    # Construct expected structure
+    var1 = relay.var("var1", shape=[1, 20], dtype="float16")
+    var2 = relay.var("var2", shape=[1, 20], dtype="float16")
+    data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
+    weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
+    r1 = var1 + var1
+    r2 = var2 + var2
+    let2 = relay.Let(
+        var2,
+        relay.cast(relay.nn.dense(r1, weight, units=20, out_dtype="float32"), "float16"),
+        r2,
+    )
+    let1 = relay.Let(
+        var1,
+        relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16"),
+        let2,
+    )
+    expected_mod = tvm.IRModule.from_expr(let1)
+    expected_mod = InferType()(expected_mod)
+
+    assert tvm.ir.structural_equal(expected_mod, output_mod)
+
+
+def test_where_simple():
+    data = relay.var("data", shape=[1, 20])
+    weight = relay.var("weight", shape=[20, 20])
+    a = relay.nn.dense(data, weight, units=20)
+    b = relay.where(data, a, a)
+    mod = tvm.IRModule.from_expr(b)
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"),
+    }
+
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+
+    # Create expected module
+    data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
+    weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
+    a = relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16")
+    b = relay.where(data, a, a)
+    expected_mod = tvm.IRModule.from_expr(b)
+    expected_mod = InferType()(expected_mod)
+
+    assert tvm.ir.structural_equal(expected_mod, output_mod)
+
+
+def test_batch_matmul_simple():
+    """Batch matmul is a special case where we try to accumulate to fp16.
+
+    This is because heterogeneous accumulation dtypes do not work
+    on all platforms at the moment.
+    """
+    data = relay.var("data", shape=[1, 1, 20])
+    weight = relay.var("weight", shape=[1, 20, 20])
+    a = relay.nn.batch_matmul(data, weight)
+    mod = tvm.IRModule.from_expr(a)
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 1, 20]).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=[1, 20, 20]).astype("float32"),
+    }
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+    # Create expected module
+    data = relay.cast(relay.var("data", shape=[1, 1, 20]), "float16")
+    weight = relay.cast(relay.var("weight", shape=[1, 20, 20]), "float16")
+    a = relay.nn.batch_matmul(data, weight, out_dtype="float16")
+    expected_mod = tvm.IRModule.from_expr(a)
+    expected_mod = InferType()(expected_mod)
+    assert tvm.ir.structural_equal(expected_mod, output_mod)

Review comment:
       Done.

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }

Review comment:
       Done
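The `GetNewAttrs` helper in the hunk above carries a TODO about its long `else if` chain over attrs types. One possible direction is a C++17 fold expression that tries each attrs type in turn and stops at the first match. This is only a sketch under stated assumptions. `BaseAttrs`, `ConvAttrs`, `DenseAttrs`, `IsRewritable`, `TryModifyOutputDType`, and `ModifyOutputDType` are hypothetical stand-ins: plain structs with `dynamic_cast`, not TVM's `Attrs`/`as<T>()` machinery. An empty string plays the role of a void dtype.

```cpp
#include <cassert>
#include <string>

// Hypothetical mock attrs classes; TVM's real Attrs use as<T>() rather than dynamic_cast.
struct BaseAttrs {
  virtual ~BaseAttrs() = default;
};
struct ConvAttrs : BaseAttrs {
  std::string out_dtype;
};
struct DenseAttrs : BaseAttrs {
  std::string out_dtype;
};

// Mirrors ModifyAttrsOutputDType: only float or unset ("" stands in for void)
// out_dtype values are overwritten; integer accumulation types are kept.
inline bool IsRewritable(const std::string& dtype) {
  return dtype.empty() || dtype.rfind("float", 0) == 0;
}

// Rewrites out_dtype if `attrs` is a T; returns whether the type matched.
template <typename T>
bool TryModifyOutputDType(BaseAttrs* attrs, const std::string& accumulation_dtype) {
  if (auto* typed = dynamic_cast<T*>(attrs)) {
    if (IsRewritable(typed->out_dtype)) typed->out_dtype = accumulation_dtype;
    return true;
  }
  return false;
}

// C++17 fold over the candidate types; || stops at the first match,
// standing in for the hand-written else-if chain.
template <typename... Ts>
bool ModifyOutputDType(BaseAttrs* attrs, const std::string& accumulation_dtype) {
  return (TryModifyOutputDType<Ts>(attrs, accumulation_dtype) || ...);
}
```

The candidate list becomes a single template argument list such as `ModifyOutputDType<ConvAttrs, DenseAttrs>(attrs, "float16")`. An integer `out_dtype` still matches its type but is left unchanged, just as in the diff.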

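The `pair_hash` helper in this hunk says it uses boost's combine strategy. However, the quoted expression `h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2))` swaps the roles of the two hashes relative to the classic boost `hash_combine`, which is `seed ^= hash(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2)`. Both forms are deterministic, so map lookups stay correct; the difference only affects how well the bits mix. Below is a minimal standalone sketch of the classic form. `hash_combine`, `CastKey`, and `CastMap` are illustrative names, and a `const void*` plus a dtype string stand in for the `(const ExprNode*, DataType)` key.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Classic boost-style hash_combine: folds the hash of `value` into `seed`.
template <class T>
void hash_combine(std::size_t& seed, const T& value) {
  seed ^= std::hash<T>()(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// Hashes a std::pair by combining the hashes of its two members in order.
struct pair_hash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& pair) const {
    std::size_t seed = 0;
    hash_combine(seed, pair.first);
    hash_combine(seed, pair.second);
    return seed;
  }
};

// Illustrative key shape: (node pointer, dtype name), as in CachedCastNodes.
using CastKey = std::pair<const void*, std::string>;
using CastMap = std::unordered_map<CastKey, int, pair_hash>;
```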





[GitHub] [tvm] anijain2305 commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r646841682



##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1145,6 +1144,21 @@ def AnnotateSpans():
     Returns
     -------
     ret : tvm.transform.Pass
-        The regsistered AnnotateSpans pass.
+        The registered AnnotateSpans pass.
     """
     return _ffi_api.AnnotateSpans()
+
+
+def RewriteFP16():

Review comment:
       Do you want to call it AMPRewriter?






[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r650300650



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into fp16 form.
+ */
+#include "to_mixed_precision.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<MixedTypeConversionCategory(const CallNode*)>;
+
+// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<MixedPrecisionOpOutDType(const CallNode*)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+  const DataType mixed_precision_type;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;

Review comment:
       I reviewed the cache mechanism and I think I got the idea. Here is the example I went through:
   
   Consider an op `A (out: fp32, want: fp16)`. After processing A's output, the cache looks like this:
   ```
   (A, fp16): cast
   (cast, fp32): A
   ```
   
   Now consider a following op `B`:
   Case 1. If `B` wants fp32, then, as you mentioned before, we query `(cast, fp32)` and get `A`, so the graph becomes `A -> B`.
   Case 2. If `B` wants fp16, then we query `(cast, fp16)`. This misses, so a new entry `(cast, fp16): cast` is created and `cast` is returned, and the graph becomes `A -> cast -> B`.
   
   This mechanism seems to work well, and the cache size should be reasonable since it only holds pointers. Two possible improvements:
   1. The cache entry `(cast, fp16): cast` in the example is not necessary. I think we can simply return `expr` when `expr_dtype == wanted_dtype`?
   2. The created `cast` ops may be unused, such as the one in case 1. Is it possible to create this op lazily? For example, when casting the output, we only create a cache entry but don't actually create the node. The first time a subsequent op queries the entry, we create the cast node and update the cache.
   
   
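The walkthrough above can be checked with a small standalone model of `CachedCast` that has improvement 1 applied. This is a sketch, not the pass itself. `Node`, `CastKey`, `KeyHash`, and `CastCache` are hypothetical stand-ins for Relay expressions and the `CachedCastNodes` map. Dtypes are plain strings, and only dtypes starting with "float" are treated as castable, mirroring the `is_float()` check. The lazy creation from improvement 2 is not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for a Relay expression: a node with a dtype and,
// for cast nodes, a pointer to the node being cast.
struct Node {
  std::string name;
  std::string dtype;
  const Node* input = nullptr;
};

using CastKey = std::pair<const Node*, std::string>;

struct KeyHash {
  std::size_t operator()(const CastKey& key) const {
    return std::hash<const Node*>()(key.first) ^ (std::hash<std::string>()(key.second) << 1);
  }
};

class CastCache {
 public:
  // Returns `expr` converted to `wanted_dtype`, reusing earlier casts when possible.
  const Node* CachedCast(const Node* expr, const std::string& wanted_dtype) {
    // Non-float tensors (e.g. integer indices) are never cast.
    if (expr->dtype.rfind("float", 0) != 0) return expr;
    // Improvement 1: no cache entry is needed when the dtype already matches.
    if (expr->dtype == wanted_dtype) return expr;

    auto it = cache_.find({expr, wanted_dtype});
    if (it != cache_.end()) return it->second;

    owned_.push_back(std::make_unique<Node>(Node{"cast", wanted_dtype, expr}));
    const Node* cast = owned_.back().get();
    cache_[{expr, wanted_dtype}] = cast;
    // Reverse entry: casting the cast back to the original dtype yields the
    // original node rather than a second cast.
    cache_[{cast, expr->dtype}] = expr;
    return cast;
  }

  // Number of cast nodes created so far.
  std::size_t num_casts() const { return owned_.size(); }

 private:
  std::unordered_map<CastKey, const Node*, KeyHash> cache_;
  std::vector<std::unique_ptr<Node>> owned_;
};
```

In this model, case 1 corresponds to `CachedCast(to16, "float32")`, which returns `A` through the reverse entry. Case 2 corresponds to `CachedCast(to16, "float16")`, which returns the existing cast without adding a cache entry.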

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into fp16 form.
+ */
+#include "to_mixed_precision.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<MixedTypeConversionCategory(const CallNode*)>;
+
+// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<MixedPrecisionOpOutDType(const CallNode*)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+  const DataType mixed_precision_type;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;

Review comment:
       I reviewed the cache mechanism and I think I got the idea. Here is the example I went through:
   
   Consider the op `A (out: fp32, want: fp16)`. After processing A's output, the cache looks like this:
   ```
   (A, fp16): cast
   (cast, fp32): A
   ```
   
   Now consider the following op `B`:
   Case 1. If `B` wants fp32, then, as you mentioned before, we query `(cast, fp32)` and get `A`, so the graph becomes `A -> B`.
   Case 2. If `B` wants fp16, then we query `(cast, fp16)`, which misses, so a new entry `(cast, fp16): cast` is created and returned, and the graph becomes `A -> cast -> B`.
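
   To make the trace above concrete, here is a small stand-alone C++ model of `CachedCast`. It is only a sketch under assumptions: `Node`, `DType`, `CastCache` and `ReplayTrace` are hypothetical toy types, not TVM's `Expr`/`DataType`, and the reverse entry `(cast, fp32): A` is recorded inside the helper for brevity, since the hunk above does not show where the real pass adds it.

   ```cpp
   #include <cassert>
   #include <cstddef>
   #include <functional>
   #include <iostream>
   #include <memory>
   #include <string>
   #include <unordered_map>
   #include <utility>
   #include <vector>

   // Toy stand-ins for Relay nodes (hypothetical; not TVM's API).
   enum class DType { kFP16, kFP32 };

   struct Node {
     std::string name;
     DType dtype;
     const Node* input;  // source of a cast, nullptr otherwise
   };

   using Key = std::pair<const Node*, DType>;

   struct KeyHash {
     std::size_t operator()(const Key& k) const {
       std::size_t seed = std::hash<const Node*>()(k.first);
       std::size_t h = std::hash<int>()(static_cast<int>(k.second));
       return seed ^ (h + 0x9e3779b9 + (seed << 6) + (seed >> 2));  // hash_combine-style mix
     }
   };

   class CastCache {
    public:
     // Returns a node of dtype `wanted` for `expr`, reusing earlier results.
     const Node* CachedCast(const Node* expr, DType wanted) {
       auto it = cache_.find({expr, wanted});
       if (it != cache_.end()) return it->second;
       const Node* result = expr;
       if (expr->dtype != wanted) {
         owned_.push_back(std::make_unique<Node>(Node{"cast", wanted, expr}));
         result = owned_.back().get();
         cache_[{result, expr->dtype}] = expr;  // assumed reverse entry, e.g. (cast, fp32) -> A
         ++casts_created_;
       }
       // In Case 2 this stores the (cast, fp16) -> cast entry the review calls unnecessary.
       cache_[{expr, wanted}] = result;
       return result;
     }
     int casts_created() const { return casts_created_; }

    private:
     std::unordered_map<Key, const Node*, KeyHash> cache_;
     std::vector<std::unique_ptr<Node>> owned_;
     int casts_created_ = 0;
   };

   // Replays the trace; returns "<B's input in Case 1> <B's input in Case 2> <casts built>".
   std::string ReplayTrace() {
     Node a{"A", DType::kFP32, nullptr};
     CastCache cache;
     const Node* a_fp16 = cache.CachedCast(&a, DType::kFP16);     // A's output, cast to fp16
     const Node* case1 = cache.CachedCast(a_fp16, DType::kFP32);  // B wants fp32: gets A
     const Node* case2 = cache.CachedCast(a_fp16, DType::kFP16);  // B wants fp16: gets the cast
     return case1->name + " " + case2->name + " " + std::to_string(cache.casts_created());
   }

   int main() {
     const std::string trace = ReplayTrace();
     assert(trace == "A cast 1");
     std::cout << trace << "\n";
   }
   ```

   Running it prints `A cast 1`: Case 1 is wired straight back to `A`, Case 2 reuses the single cast node.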
   
   This mechanism seems to work well, and the cache size should be reasonable since it only stores pointers. Two possible improvements:
   1. The cache entry `(cast, fp16): cast` in the example is not necessary. I think we can simply return `expr` when `expr_dtype == wanted_dtype`?
   2. The created `cast` ops may be useless, such as the one in Case 1. Is it possible to create this op lazily? For example, when casting the output, we would only create a cache entry without actually creating the node. The first time a following op queries the entry, we create the cast node and update the cache.
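
   Improvement 2 could look roughly like the sketch below. It reuses the same kind of hypothetical toy types (`Node`, `DType`) and invents `PendingCast`/`LazyCaster` purely for illustration: casting an output only records a pending cast, and a cast node is built the first time a consumer actually needs a dtype different from the source's. Under these assumptions Case 1 builds no cast at all, and Case 2 builds exactly one that later fp16 consumers share.

   ```cpp
   #include <cassert>
   #include <iostream>
   #include <map>
   #include <memory>
   #include <string>
   #include <vector>

   // Toy stand-ins (hypothetical; not TVM's API).
   enum class DType { kFP16, kFP32 };

   struct Node {
     std::string name;
     DType dtype;
     const Node* input;  // source of a cast, nullptr otherwise
   };

   // A cast that has been requested for an op's output but not built yet.
   struct PendingCast {
     const Node* source;
     DType target;                        // dtype the producer asked for
     std::map<DType, const Node*> built;  // casts materialized so far, keyed by dtype
   };

   class LazyCaster {
    public:
     // Casting an op's output only records the intent.
     PendingCast* Defer(const Node* source, DType target) {
       pending_.push_back(std::make_unique<PendingCast>(PendingCast{source, target, {}}));
       return pending_.back().get();
     }

     // Each consumer resolves the pending cast to the dtype it wants.
     const Node* Resolve(PendingCast* p, DType wanted) {
       if (wanted == p->source->dtype) return p->source;  // no cast node is ever built
       auto it = p->built.find(wanted);
       if (it != p->built.end()) return it->second;  // reuse the cast from the first query
       nodes_.push_back(std::make_unique<Node>(Node{"cast", wanted, p->source}));
       ++casts_created_;
       p->built[wanted] = nodes_.back().get();
       return nodes_.back().get();
     }

     // Consumers that simply follow the producer ask for its target dtype.
     const Node* ResolveDefault(PendingCast* p) { return Resolve(p, p->target); }

     int casts_created() const { return casts_created_; }

    private:
     std::vector<std::unique_ptr<PendingCast>> pending_;
     std::vector<std::unique_ptr<Node>> nodes_;
     int casts_created_ = 0;
   };

   // Returns the number of cast nodes built after Case 1 and after Case 2.
   std::vector<int> ReplayLazy() {
     Node a{"A", DType::kFP32, nullptr};
     LazyCaster caster;
     PendingCast* a_out = caster.Defer(&a, DType::kFP16);
     std::vector<int> counts;

     // Case 1: B wants fp32 and receives A directly.
     const Node* b_input = caster.Resolve(a_out, DType::kFP32);
     assert(b_input == &a);
     counts.push_back(caster.casts_created());

     // Case 2: two fp16 consumers share one cast, built on the first query.
     const Node* c1 = caster.ResolveDefault(a_out);
     const Node* c2 = caster.Resolve(a_out, DType::kFP16);
     assert(c1 == c2 && c1->name == "cast");
     counts.push_back(caster.casts_created());
     return counts;
   }

   int main() {
     const std::vector<int> counts = ReplayLazy();
     std::cout << "casts after Case 1: " << counts[0] << ", after Case 2: " << counts[1] << "\n";
   }
   ```

   The trade-off is an extra indirection (`PendingCast`) that every consumer must resolve, in exchange for never emitting a cast that nothing reads.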
   
   Another direction, which I would actually recommend, is to remove the cache, let this pass generate as many cast ops as it wants, and then run the SimplifyExpr pass to cancel back-to-back cast ops. IIUC, this should generate the same IR as the current pass, so it wouldn't hurt the final performance (please correct me if I missed something).
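
   The rewrite this direction relies on is small. The sketch below is a hypothetical stand-alone version on a toy single-input expression tree; it does not use or describe SimplifyExpr's actual implementation. It replaces `cast(cast(x, t1), t2)` with `x` when `x` already has dtype `t2`. Note that an fp32 -> fp16 -> fp32 round trip is not an identity on values, since the fp16 step rounds; dropping it is only appropriate here because the conversion pass itself inserted those casts, and the result, `A -> B`, matches what the cache produces in Case 1.

   ```cpp
   #include <cassert>
   #include <iostream>
   #include <memory>
   #include <string>
   #include <utility>
   #include <vector>

   // Toy expression tree (hypothetical; not TVM's API).
   enum class DType { kFP16, kFP32 };

   struct Expr;
   using ExprPtr = std::shared_ptr<const Expr>;

   struct Expr {
     std::string op;  // "var", "cast", or the name of a consumer op
     DType dtype;     // output dtype
     ExprPtr arg;     // single input; nullptr for leaves
   };

   ExprPtr Make(std::string op, DType dtype, ExprPtr arg = nullptr) {
     return std::make_shared<Expr>(Expr{std::move(op), dtype, std::move(arg)});
   }

   // Rewrites cast(cast(x, t1), t2) into x whenever x already has dtype t2.
   ExprPtr CancelRoundTripCasts(const ExprPtr& e) {
     if (!e->arg) return e;
     ExprPtr arg = CancelRoundTripCasts(e->arg);
     if (e->op == "cast" && arg->op == "cast" && arg->arg && arg->arg->dtype == e->dtype) {
       return arg->arg;
     }
     return arg == e->arg ? e : Make(e->op, e->dtype, arg);
   }

   int CountCasts(const ExprPtr& e) {
     return e ? (e->op == "cast" ? 1 : 0) + CountCasts(e->arg) : 0;
   }

   // Returns cast counts: round-trip chain before, after, and a genuine fp16 consumer after.
   std::vector<int> CastCountsBeforeAfter() {
     ExprPtr x = Make("var", DType::kFP32);
     ExprPtr b = Make("B", DType::kFP32, Make("cast", DType::kFP32, Make("cast", DType::kFP16, x)));
     ExprPtr simplified = CancelRoundTripCasts(b);
     assert(simplified->arg == x);  // B now reads x directly, like A -> B in Case 1
     ExprPtr c = Make("C", DType::kFP16, Make("cast", DType::kFP16, x));  // needed cast stays
     return {CountCasts(b), CountCasts(simplified), CountCasts(CancelRoundTripCasts(c))};
   }

   int main() {
     const std::vector<int> counts = CastCountsBeforeAfter();
     std::cout << "round trip: " << counts[0] << " -> " << counts[1]
               << " casts; fp16 consumer keeps " << counts[2] << "\n";
   }
   ```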
   







[GitHub] [tvm] anijain2305 commented on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-856966217


   Can we get a few more initial reviews? @mbrookhart, @csullivan
   
   @AndrewZhaoLuo I would also suggest testing a dynamic model such as SSD or Mask-RCNN. Your current list of object detection models includes Yolo, which is a static model.





[GitHub] [tvm] MeJerry215 commented on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
MeJerry215 commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-1002904713


   It seems that all conv, matmul, and dense ops are cast to fp16 when using the mixed precision pass.
   ![image](https://user-images.githubusercontent.com/53092165/147730311-9c0567f3-a142-431e-9204-b9004dee6507.png)
   But I still have a question: why not cast the weights to fp16 directly and remove all the cast ops?
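
   The idea in the question can be illustrated with a small sketch, under stated assumptions: a hypothetical toy IR (not TVM's API) that tracks only op names and dtypes, with no tensor values and no actual fp32-to-fp16 rounding. If a weight is a compile-time constant, a cast applied to it can be evaluated ahead of time, leaving a weight stored in fp16 and no runtime cast. A cast on a runtime input such as `data` cannot be folded this way, so not every cast can be removed.

   ```cpp
   #include <cassert>
   #include <iostream>
   #include <memory>
   #include <string>
   #include <utility>
   #include <vector>

   // Toy expression graph (hypothetical; not TVM's API). Only op names and dtypes are
   // modeled: no tensor values and no fp32 -> fp16 rounding.
   enum class DType { kFP16, kFP32 };

   struct Expr;
   using ExprPtr = std::shared_ptr<const Expr>;

   struct Expr {
     std::string op;  // "const" (e.g. a weight), "var" (runtime input), "cast", or a compute op
     DType dtype;
     std::vector<ExprPtr> args;
   };

   ExprPtr Make(std::string op, DType dtype, std::vector<ExprPtr> args = {}) {
     return std::make_shared<Expr>(Expr{std::move(op), dtype, std::move(args)});
   }

   // Replaces cast(const) with a const of the cast's dtype, i.e. converts the weight ahead of time.
   ExprPtr FoldConstantCasts(const ExprPtr& e) {
     std::vector<ExprPtr> args;
     for (const ExprPtr& a : e->args) args.push_back(FoldConstantCasts(a));
     if (e->op == "cast" && args.size() == 1 && args[0]->op == "const") {
       return Make("const", e->dtype);  // a real pass would also convert the stored values
     }
     return Make(e->op, e->dtype, std::move(args));
   }

   int CountCasts(const ExprPtr& e) {
     int n = e->op == "cast" ? 1 : 0;
     for (const ExprPtr& a : e->args) n += CountCasts(a);
     return n;
   }

   // conv2d(cast(data -> fp16), cast(weight -> fp16)): returns the cast count before/after folding.
   std::vector<int> FoldedCastCounts() {
     ExprPtr data = Make("var", DType::kFP32);      // runtime input: its cast must stay
     ExprPtr weight = Make("const", DType::kFP32);  // compile-time constant: its cast can fold
     ExprPtr conv = Make("conv2d", DType::kFP16,
                         {Make("cast", DType::kFP16, {data}), Make("cast", DType::kFP16, {weight})});
     ExprPtr folded = FoldConstantCasts(conv);
     assert(folded->args[1]->op == "const" && folded->args[1]->dtype == DType::kFP16);
     return {CountCasts(conv), CountCasts(folded)};
   }

   int main() {
     const std::vector<int> counts = FoldedCastCounts();
     std::cout << "casts before: " << counts[0] << ", after: " << counts[1] << "\n";
   }
   ```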





[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652211945



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));
+
+        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!";
+        if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!";
+
+        // If not registered, by default assume is a generic FOLLOW operation.
+        initial_category = MIXED_PRECISION_FOLLOW;
+        accumulation_dtype = DataType::Float(16);
+        output_dtype = DataType::Float(16);
+      }
+    } else {
+      LOG(FATAL) << "Unsupported op type in CallNode: " << pre_call_node->op;
+    }
+
+    // First check if all the new mutated args are in lower precision form
+    Array<Type> cur_arg_types;
+    bool all_args_mixed_type_compatible = true;
+    for (Expr arg : post_call_node->args) {
+      Type cur_arg_type = GetType(arg);
+      cur_arg_types.push_back(cur_arg_type);
+
+      if (initial_category == MIXED_PRECISION_FOLLOW && all_args_mixed_type_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_mixed_type_compatible = IsMixedPrecisionType(cur_arg_type, true) ||
+                                        arg->IsInstance<VarNode>() ||
+                                        arg->IsInstance<ConstantNode>();
+        all_args_mixed_type_compatible &= is_mixed_type_compatible;
+      }
+    }
+
+    // Determine the final category we want for conversion
+    MixedTypeConversionCategory final_category;
+    if (initial_category == MIXED_PRECISION_FOLLOW) {
+      final_category =
+          all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER;
+    } else {
+      final_category = initial_category;
+    }
+
+    // Create the new arguments to the call.
+    DataType wanted_arg_dtypes =
+        final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type : DataType::Float(32);
+    auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes);
+    Array<Expr> new_args = call_args_and_types.first;
+    Array<Type> new_arg_types;
+
+    if (pre_call_node->op.as<FunctionNode>()) {
+      // Function Nodes don't store type info in the Call, it should be a []
+      new_arg_types = pre_call_node->type_args;
+    } else {
+      new_arg_types = call_args_and_types.second;
+    }
+
+    // Finally create the new attributes.
+    if (final_category == MIXED_PRECISION_ALWAYS) {
+      Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype);
+      Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span);
+      if (accumulation_dtype != output_dtype) {
+        output = CastArg(output, GetType(output), output_dtype);
+      }
+      return output;
+    }
+
+    return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
+  }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    // Erase the ret_type annotation and let the normal pass recalculate
+    const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
+    return ExprMutator::VisitExpr_(func);
+  }
+
+  Expr VisitExpr_(const LetNode* op) final {
+    // First convert as much of the bound computation to lower precision as possible
+    Expr value = this->Mutate(op->value);
+
+    // Then rewrite the var type and associated expression
+    Var var = Downcast<Var>(this->Mutate(op->var));
+    VarNode* mutable_var = const_cast<VarNode*>((op->var).as<VarNode>());
+    mutable_var->type_annotation = GetType(value);
+    mutable_var->checked_type_ = mutable_var->type_annotation;
+
+    // Mutate body last as it may depend on previous results
+    Expr body = this->Mutate(op->body);
+    return Let(var, value, body, op->span);
+  }
+};
+
+Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type,
+                      bool ignore_missing_ops, bool warn_missing_ops) {
+  MixedPrecisionPass converter =
+      MixedPrecisionPass(mixed_precision_type, ignore_missing_ops, warn_missing_ops);
+  auto result = converter.Mutate(expr);
+  return result;
+}
+
+namespace transform {
+
+Pass ToMixedPrecision(DataType mixed_precision_type, bool ignore_missing_ops,
+                      bool warn_missing_ops) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(
+            ToMixedPrecision(f, mixed_precision_type, ignore_missing_ops, warn_missing_ops));
+      };
+  return CreateFunctionPass(pass_func, 10, "ToMixedPrecision", {});

Review comment:
       Done
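   For readers following the `CachedCast` helper in the hunk above, here is a minimal standalone sketch of the same idea. It memoizes casts keyed on (node, wanted dtype) and also records a reverse entry, so casting a cast node back to its source dtype returns the original node instead of stacking a second cast. `Node`, `CastCache` and `num_casts_created` are hypothetical stand-ins, not TVM types. Dtypes are plain strings here, while the real pass uses `DataType` and Relay `Expr` nodes.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for a Relay expression: only a name and a dtype string.
struct Node {
  std::string name;
  std::string dtype;
};
using NodePtr = std::shared_ptr<Node>;

// Pair hash using the same boost-style combine as pair_hash in the diff.
struct PairHash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& p) const {
    std::size_t h1 = std::hash<T1>()(p.first);
    std::size_t h2 = std::hash<T2>()(p.second);
    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
  }
};

class CastCache {
 public:
  int num_casts_created = 0;  // counts how many new cast nodes were built

  // Returns `expr` cast to `wanted`, reusing an earlier cast when one exists.
  NodePtr CachedCast(const NodePtr& expr, const std::string& wanted) {
    bool is_float = expr->dtype.rfind("float", 0) == 0;
    if (!is_float || expr->dtype == wanted) return expr;  // never cast ints; skip no-ops

    auto it = cache_.find({expr.get(), wanted});
    if (it != cache_.end()) return it->second;

    auto result = std::make_shared<Node>(Node{"cast(" + expr->name + ")", wanted});
    ++num_casts_created;
    cache_[{expr.get(), wanted}] = result;
    // Reverse entry: casting the new node back to the original dtype
    // returns the original node instead of stacking a second cast.
    cache_[{result.get(), expr->dtype}] = expr;
    return result;
  }

 private:
  // Keys hold raw pointers; the stored NodePtr values keep every keyed node alive.
  std::unordered_map<std::pair<const Node*, std::string>, NodePtr, PairHash> cache_;
};
```

   Casting `x` (float32) to float16 twice builds one cast node, and casting that node back to float32 returns `x` itself.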







[GitHub] [tvm] masahi commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

masahi commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r649735781



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into fp16 form.
+ */
+#include "to_mixed_precision.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<MixedTypeConversionCategory(const CallNode*)>;
+
+// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<MixedPrecisionOpOutDType(const CallNode*)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+  const DataType mixed_precision_type;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);

Review comment:
       Can we create and return a new attribute?
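   A minimal standalone sketch of the copy-then-modify approach this question suggests, compared with writing through `const_cast`. The caller's attributes, which other calls may share, are left untouched, and a fresh copy carries the new accumulation dtype. `ConvAttrs`, `IsFloatOrUnset` and `WithOutputDType` are hypothetical names. `std::shared_ptr` and `std::make_shared` stand in for TVM's `ObjectPtr` and `make_object`, and an empty string stands in for a void dtype.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical attribute struct standing in for something like Conv2DAttrs.
struct ConvAttrs {
  int channels = 0;
  std::string out_dtype;  // empty string plays the role of an unset (void) dtype
};

// True for float dtypes and for the unset dtype, mirroring is_float() || is_void().
bool IsFloatOrUnset(const std::string& dtype) {
  return dtype.empty() || dtype.rfind("float", 0) == 0;
}

// Copy-then-modify: the input attributes are never written to; a fresh
// object carries the new accumulation dtype and is returned to the caller.
template <typename T>
std::shared_ptr<const T> WithOutputDType(const std::shared_ptr<const T>& attrs,
                                         const std::string& accumulation_dtype) {
  auto new_attrs = std::make_shared<T>(*attrs);  // plays the role of make_object<T>(*attrs)
  if (IsFloatOrUnset(new_attrs->out_dtype)) new_attrs->out_dtype = accumulation_dtype;
  return new_attrs;
}
```

   Because the original object is never mutated, any other call node still pointing at it keeps its fp32 `out_dtype`. Mutating through `const_cast` would silently change those nodes as well.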







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648609224



##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;

Review comment:
       A contrived example with concat:
   
   We have two branches whose outputs are fed into a concat.
   
   The first branch has a RED operation and returns an FP32 tensor.
   The second branch returns an FP16 tensor.
   
   Now that I think about it, it might be better to be a bit smarter about GRAY ops when heterogeneous floating point types come in.
   
   For example, say we had a concat with 10 fp16 args and 1 fp32 arg. It would be wasteful to convert everything to fp32 by default and color the op RED in this case.
   
   I will change this so that the number of fp16 and fp32 args is taken into account: if fp16 args are in the majority or there is a tie, we color the op GREEN; otherwise we color it RED. Sound good?
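   A minimal sketch of the majority rule proposed above for GRAY ops, written as a standalone function. This is an illustration of the proposal in this comment, not the code that was merged. `Color`, `ArgKind` and `DecideGrayColor` are hypothetical names. Non-float arguments are excluded from the vote, and the sketch does not model the diff's special case of treating Vars and Constants as freely castable.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum class Color { GREEN, RED };
// FP16/FP32 are floating-point args; OTHER covers ints and anything not voted on.
enum class ArgKind { FP16, FP32, OTHER };

// Proposed rule for GRAY ops: GREEN if fp16 args outnumber or tie fp32 args.
Color DecideGrayColor(const std::vector<ArgKind>& args) {
  std::size_t num_fp16 = 0, num_fp32 = 0;
  for (ArgKind kind : args) {
    if (kind == ArgKind::FP16) {
      ++num_fp16;
    } else if (kind == ArgKind::FP32) {
      ++num_fp32;
    }
  }
  return num_fp16 >= num_fp32 ? Color::GREEN : Color::RED;
}
```

   For the concat with 10 fp16 args and 1 fp32 arg this returns GREEN, so only the single fp32 input needs a cast to fp16 rather than ten inputs needing casts to fp32.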

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;

Review comment:
       A contrived example with concat.
   
   We have two branches whose outputs are fed into concat.
   
   The first branch has a RED operation and returns an FP32 tensor.
   The second branch returns an FP16 tensor.
   
   Now that I say this, it might be better to be a bit smarter about GRAY ops when heterogeneous floating point types come in.
   
   E.g. let's say we had a concat with 10 fp16 args and 1 fp32 arg. It would be wasteful to convert everything to fp32 by default and color the op RED in this case.
   
   I will change this so the numbers of fp16 and fp32 args are taken into account: if fp16 args are in the majority, or there is a tie, we color the op GREEN; otherwise we color it RED. Thoughts?
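The majority rule proposed above can be sketched as follows. This is a minimal, self-contained illustration, not the PR's implementation: `Color`, `ArgKind`, `ResolveGrayColor`, and `FinalColor` are hypothetical names, and counting Vars and Constants as fp16 votes is an assumption carried over from the diff, which treats them as fp16-compatible because they can be cast directly.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the pass's conversion colors:
// GREEN = run in fp16, GRAY = follow the inputs, RED = keep in fp32.
enum class Color { kGreen, kGray, kRed };

// Hypothetical summary of one call argument after it has been visited.
enum class ArgKind {
  kFP16,      // floating point tensor already in fp16
  kFP32,      // floating point tensor still in fp32
  kCastable,  // Var or Constant: cheap to cast, counted as fp16-compatible (assumption)
  kNonFloat   // integer/bool tensor: does not take part in the vote
};

// Resolve a GRAY op by majority vote over its floating point arguments.
// Ties (including "no floating point args at all") resolve to GREEN.
Color ResolveGrayColor(const std::vector<ArgKind>& args) {
  std::size_t fp16_votes = 0;
  std::size_t fp32_votes = 0;
  for (ArgKind kind : args) {
    switch (kind) {
      case ArgKind::kFP16:
      case ArgKind::kCastable:
        ++fp16_votes;
        break;
      case ArgKind::kFP32:
        ++fp32_votes;
        break;
      case ArgKind::kNonFloat:
        break;
    }
  }
  return fp16_votes >= fp32_votes ? Color::kGreen : Color::kRed;
}

// Only GRAY ops are re-colored; GREEN and RED keep their initial color.
Color FinalColor(Color initial, const std::vector<ArgKind>& args) {
  return initial == Color::kGray ? ResolveGrayColor(args) : initial;
}
```

With the concat example (10 fp16 args and 1 fp32 arg), `FinalColor(Color::kGray, args)` yields `kGreen`, so only the single fp32 input needs a cast instead of ten.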







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652190632



##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,18 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def ToMixedPrecision(
+    mixed_precision_type="float16", ignore_missing_ops=True, warn_missing_ops=True
+):
+    """
+    Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version
+    where as many operations as possible are in the target mixed_precision_type.

Review comment:
       Done







[GitHub] [tvm] AndrewZhaoLuo commented on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-865193717


   Added some tracking issues for CUDA and Vulkan:
   https://github.com/apache/tvm/issues/8295 
   https://github.com/apache/tvm/issues/8294





[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652211415



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));
+
+        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!";
+        if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!";

Review comment:
       Done; it's now a single flag, `missing_op_mode`, with the suggested behavior.







[GitHub] [tvm] AndrewZhaoLuo commented on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-855128163


   Hey Animesh, it'll be ready for review soon, probably by Monday morning (PST).
   
   There are still some miscellaneous improvements to be made, but I've decided to defer those to later PRs.





[GitHub] [tvm] AndrewZhaoLuo commented on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-861727219


   @anijain2305 
   @masahi 
   @comaniac 
   @mbrookhart 
   @csullivan 
   
   PTAL. I believe I've addressed all the major points.





[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652142144



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }

Review comment:
       Hey Matthew, do you have an example model I could use to understand this problem a little bit more?
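The later revision of `GetNewAttrs` at the top of this section copies the attrs node (`make_object<T>(*attrs)`) and rewrites `out_dtype`/`dtype` on the copy. That way the original call's attrs, which the unconverted graph still references, are left untouched. The following is a minimal Python sketch of that copy-on-write rule. `ConvAttrs` is a hypothetical frozen dataclass standing in for TVM's attrs nodes, and the empty string stands in for TVM's void dtype. Neither is TVM's API.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ConvAttrs:
    """Hypothetical stand-in for an attrs node that carries an out_dtype field."""
    strides: tuple = (1, 1)
    # "" plays the role of a void dtype, i.e. "same as the input dtype".
    out_dtype: str = ""


def is_float_or_void(dtype: str) -> bool:
    return dtype == "" or dtype.startswith("float")


def modify_attrs_out_dtype(attrs: ConvAttrs, accumulation_dtype: str) -> ConvAttrs:
    """Return a modified copy; the caller's attrs object is never mutated."""
    if is_float_or_void(attrs.out_dtype):
        return replace(attrs, out_dtype=accumulation_dtype)
    # Non-float out_dtypes (e.g. int32 accumulation for quantized ops) are kept as-is.
    return attrs


shared = ConvAttrs(out_dtype="float32")
converted = modify_attrs_out_dtype(shared, "float16")
print(shared.out_dtype, converted.out_dtype)  # float32 float16
```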







[GitHub] [tvm] AndrewZhaoLuo commented on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-856160404


   This is ready for review





[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648649618



##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,20 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def AMPRewrite():

Review comment:
       I'll go with autocast.







[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648618186



##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;

Review comment:
       The workaround sounds fine to me. Again, I'd suggest putting these op-specific heuristics into an op attribute instead of this pass.
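The color logic in the hunk above can be sketched in Python. An op that starts out GRAY becomes GREEN when every mutated argument is fp16-compatible (already fp16, or a Var/Constant that can simply be cast), and RED otherwise. This is a minimal sketch, not the pass itself. `Color` and `Arg` are hypothetical stand-ins, and non-float arguments are skipped following the documented intent of `ignore_non_float`. Because the hunk is cut off after the GRAY branch, other colors are assumed to keep their initial value.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import List


class Color(Enum):
    GREEN = auto()  # run in fp16
    GRAY = auto()   # decide from the arguments
    RED = auto()    # keep in fp32


@dataclass
class Arg:
    """Hypothetical summary of an already-mutated call argument."""
    dtype: str
    kind: str = "call"  # "var", "constant" or "call"


def arg_is_fp16_compatible(arg: Arg) -> bool:
    # Vars and constants can always be cast, so their current dtype does not matter.
    if arg.kind in ("var", "constant"):
        return True
    # Non-float arguments (e.g. int32 indices) are ignored.
    if not arg.dtype.startswith("float"):
        return True
    return arg.dtype == "float16"


def final_color(initial: Color, args: List[Arg]) -> Color:
    if initial is Color.GRAY:
        return Color.GREEN if all(arg_is_fp16_compatible(a) for a in args) else Color.RED
    # Assumption: GREEN and RED ops keep their initial color.
    return initial
```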







[GitHub] [tvm] csullivan commented on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
csullivan commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-857915269


   Thanks for this great PR! Would it be too much to ask for AMPRewrite and the corresponding infra to support mixed precision with generic reduced-precision floating point types? I notice the main assumption is downcasting to float16, though TVM supports other reduced-precision floating point types for which mixed precision is useful, e.g. float32 + bfloat16, as well as possible user-defined floating point types.
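One way to approach this request is to thread the target dtype through helpers that currently hard-code `DataType::Float(16)`, such as `IsFP16Type`. The following is a minimal Python sketch of such a parametrized check. It assumes tensor types are modeled as dtype strings and tuple types as Python tuples, which are hypothetical stand-ins for Relay's `TensorType` and `TupleType`.

```python
from typing import Tuple, Union

# A tensor type is modeled as its dtype string; a tuple type as a tuple of types.
TypeLike = Union[str, Tuple["TypeLike", ...]]


def is_mixed_precision_type(t: TypeLike, target_dtype: str) -> bool:
    """True if every tensor inside t has dtype target_dtype (e.g. float16 or bfloat16)."""
    if isinstance(t, str):
        return t == target_dtype
    if isinstance(t, tuple):
        # Tuple types qualify only if all of their fields do (recursively).
        return all(is_mixed_precision_type(field, target_dtype) for field in t)
    raise TypeError(f"Unsupported type {t!r}")
```

The same check then serves float16, bfloat16 or a custom type, with the target chosen once when the pass is constructed.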





[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r651148846



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into fp16 form.
+ */
+#include "to_mixed_precision.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<MixedTypeConversionCategory(const CallNode*)>;
+
+// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<MixedPrecisionOpOutDType(const CallNode*)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+  const DataType mixed_precision_type;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;

Review comment:
       1. Right now, that is functionally what happens with this line:
   `Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);`
   
   It still creates a cache entry, though, so I reorganized it to be clearer and to skip inserting into the cache when `expr_dtype == wanted_dtype`.
   
   2. Hmm, I believe creating the op lazily will not have any benefit, because there aren't any useless casts (see 1).
   
   The idea of having another pass handle back-to-back casts is appealing, since such a tool can be used in many other situations. My main concern is correctness, e.g. does it handle weird edge cases well? I'll take a closer look at the existing PR and think a little more about this.
   
   I do agree that this is a better direction, however, and will refactor the pass once a sufficiently correct cast-folding pass exists and is checked into main.
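The reorganization described in point 1 can be sketched in Python. When the expression already has the wanted dtype, it is returned unchanged and nothing is cached. Otherwise, a cast is built once per `(node, dtype)` pair, and the reverse mapping is recorded so that casting back yields the original node. `Expr` and `make_cast` are hypothetical stand-ins, not Relay APIs. Identity-hashed dataclass instances used as dictionary keys play the role of the C++ `ExprNode*` key and `pair_hash`.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass(eq=False)  # eq=False keeps identity-based hashing, like keying on ExprNode*
class Expr:
    """Hypothetical expression node: a name, a dtype and an optional cast source."""
    name: str
    dtype: str
    source: Optional["Expr"] = None


def make_cast(expr: Expr, dtype: str) -> Expr:
    return Expr(f"cast({expr.name}, {dtype})", dtype, source=expr)


class CastCache:
    def __init__(self) -> None:
        self._cache: Dict[Tuple[Expr, str], Expr] = {}

    def cached_cast(self, expr: Expr, wanted_dtype: str) -> Expr:
        # Non-float tensors (e.g. integer indices) are never cast.
        if not expr.dtype.startswith("float"):
            return expr
        # Already the right dtype: nothing to build and nothing inserted into the cache.
        if expr.dtype == wanted_dtype:
            return expr
        key = (expr, wanted_dtype)
        if key in self._cache:
            return self._cache[key]
        result = make_cast(expr, wanted_dtype)
        self._cache[key] = result
        # Reverse entry: casting the new node back yields the original node.
        self._cache[(result, expr.dtype)] = expr
        return result
```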







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648649618



##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,20 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def AMPRewrite():

Review comment:
       Hmm, I would still prefer `ToMixedPrecision` if that is fine with you.
   
   The example you list only works for me because it exists under the amp namespace.







[GitHub] [tvm] comaniac commented on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-878686422


   Yes. Please go to the discuss forum. You can refer to this PR and tag relevant people in the post.





[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r653826430



##########
File path: python/tvm/relay/transform/mixed_precision.py
##########
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=line-too-long,unused-argument
+"""Default behavior for ops in mixed_precision pass. Import this file to use."""
+from typing import List
+
+from tvm import relay
+from tvm.relay.op import register_mixed_precision_conversion
+
+# MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+# savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+# justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+# numerical reasons.
+MIXED_PRECISION_ALWAYS = 0
+MIXED_PRECISION_FOLLOW = 1
+MIXED_PRECISION_NEVER = 2
+
+# Default lists inspired from TF's classifications:
+# github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+# They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
+DEFAULT_ALWAYS_LIST = [
+    "nn.conv1d",
+    "nn.conv2d",
+    "nn.conv3d",
+    "nn.conv1d_transpose",
+    "nn.conv2d_transpose",
+    "nn.conv3d_transpose",
+    "nn.dense",
+    # "nn.batch_matmul", # Handled by a special case
+]
+DEFAULT_FOLLOW_LIST = [
+    # These ops add new data or change shape
+    "nn.pad",
+    "nn.batch_flatten",
+    "concatenate",
+    "zeros",
+    "split",
+    "squeeze",
+    "transpose",
+    "expand_dims",
+    "reshape",
+    "dyn.reshape",
+    "broadcast_to_like",
+    "dyn.broadcast_to",
+    "strided_slice",
+    "dyn.strided_slice",
+    "take",
+    "argwhere",
+    "where",
+    "tile",
+    "dyn.tile",
+    "scatter",
+    "full",
+    "dyn.full",
+    # Comparison
+    "less",
+    "greater",
+    "less_equal",
+    "greater_equal",
+    # By definition copy and cast will depend on inputs for output.
+    "copy",
+    "cast",
+    "cast_like",
+    # Simple arithmetic
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "nn.bias_add",
+    "nn.batch_norm",
+    "sum",
+    "mean",
+    "sqrt",
+    "shape_of",
+    # Simple activations
+    "max",
+    "min",
+    "maximum",
+    "minimum",
+    "nn.relu",
+    "nn.leaky_relu",
+    "nn.prelu",
+    "nn.dropout",
+    # Complicated activations which saturate in a narrow range
+    "sigmoid",
+    "tanh",
+    # Pooling operations
+    "nn.max_pool1d",
+    "nn.max_pool2d",
+    "nn.max_pool3d",
+    "nn.avg_pool1d",
+    "nn.avg_pool2d",
+    "nn.avg_pool3d",
+    # "nn.global_max_pool1d", # does not exist yet
+    "nn.global_max_pool2d",
+    # "nn.global_max_pool3d", # does not exist yet
+    # "nn.global_avg_pool1d", # does not exist yet
+    "nn.global_avg_pool2d",
+    # "nn.global_avg_pool3d", # does not exist yet
+    "nn.adaptive_max_pool1d",
+    "nn.adaptive_max_pool2d",
+    "nn.adaptive_max_pool3d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+]
+DEFAULT_NEVER_LIST = [
+    # In general if |f(x)| >> |x| for expected inputs then put the op here.
+    "exp",
+    "power",
+    "nn.cross_entropy",
+    "nn.cross_entropy_with_logits",
+    "nn.softmax",
+    "nn.l2_normalize",
+    # Error function doesn't seem to be able to be lowered into fp16 version in llvm.
+    # Move to follow list when it does.
+    "erf",
+]
+
+
+# Returns a decorator which registers for every given op, the function under FTVMMixedPrecisionConversionType
+def register_func_to_op_list(list_ops):
+    def decorator(func):
+        for op_name in list_ops:
+            register_mixed_precision_conversion(op_name, func=func)
+
+    return decorator
+
+
+def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]:
+    # Assume support accumulation dtypes <---> has out_dtype attr

Review comment:
       Right now there isn't a good way to tell which ops support "accumulation" dtypes, so this uses the presence of an `out_dtype` attribute as a proxy.
   
   This is the simplest method I could think of. There is a discussion on the discuss thread about making it easier to tell which ops support heterogeneous accumulators.
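The `out_dtype` heuristic can be sketched in plain Python. The sketch returns `[accumulation_dtype, output_dtype]` and treats "attrs has an `out_dtype` field" as "op supports a separate accumulation dtype". The attrs classes are hypothetical stand-ins, not TVM's attrs nodes. Accumulating in `float32` is an assumption made for illustration.

```python
from typing import List


class AttrsWithOutDtype:
    """Hypothetical stand-in for an attrs node that has an out_dtype field (like conv2d's)."""

    def __init__(self, out_dtype: str = ""):
        self.out_dtype = out_dtype


class AttrsWithoutOutDtype:
    """Hypothetical stand-in for an attrs node with no out_dtype field."""

    def __init__(self, axis: int = 0):
        self.axis = axis


def generic_out_dtypes_sketch(attrs, mixed_precision_type: str) -> List[str]:
    """Return [accumulation_dtype, output_dtype] using the out_dtype heuristic.

    An op is assumed to support a separate accumulation dtype iff its attrs
    expose an `out_dtype` field. Accumulating in float32 is an assumption.
    """
    if attrs is not None and hasattr(attrs, "out_dtype"):
        return ["float32", mixed_precision_type]
    # No accumulation dtype to set: accumulate and output in the mixed type.
    return [mixed_precision_type, mixed_precision_type]


print(generic_out_dtypes_sketch(AttrsWithOutDtype(), "float16"))     # ['float32', 'float16']
print(generic_out_dtypes_sketch(AttrsWithoutOutDtype(), "float16"))  # ['float16', 'float16']
```

The obvious weakness is that it is purely structural. Any op that happens to have an `out_dtype` attribute is treated as an accumulating op, which is why a more explicit mechanism is being discussed.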

##########
File path: python/tvm/relay/transform/mixed_precision.py
##########
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=line-too-long,unused-argument
+"""Default behavior for ops in mixed_precision pass. Import this file to use."""
+from typing import List
+
+from tvm import relay
+from tvm.relay.op import register_mixed_precision_conversion
+
+# MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+# savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+# justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+# numerical reasons.
+MIXED_PRECISION_ALWAYS = 0
+MIXED_PRECISION_FOLLOW = 1
+MIXED_PRECISION_NEVER = 2
+
+# Default lists inspired from TF's classifications:
+# github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+# They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
+DEFAULT_ALWAYS_LIST = [
+    "nn.conv1d",
+    "nn.conv2d",
+    "nn.conv3d",
+    "nn.conv1d_transpose",
+    "nn.conv2d_transpose",
+    "nn.conv3d_transpose",
+    "nn.dense",
+    # "nn.batch_matmul", # Handled by a special case
+]
+DEFAULT_FOLLOW_LIST = [
+    # These ops add new data or change shape
+    "nn.pad",
+    "nn.batch_flatten",
+    "concatenate",
+    "zeros",
+    "split",
+    "squeeze",
+    "transpose",
+    "expand_dims",
+    "reshape",
+    "dyn.reshape",
+    "broadcast_to_like",
+    "dyn.broadcast_to",
+    "strided_slice",
+    "dyn.strided_slice",
+    "take",
+    "argwhere",
+    "where",
+    "tile",
+    "dyn.tile",
+    "scatter",
+    "full",
+    "dyn.full",
+    # Comparison
+    "less",
+    "greater",
+    "less_equal",
+    "greater_equal",
+    # By definition copy and cast will depend on inputs for output.
+    "copy",
+    "cast",
+    "cast_like",
+    # Simple arithmetic
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "nn.bias_add",
+    "nn.batch_norm",
+    "sum",
+    "mean",
+    "sqrt",
+    "shape_of",
+    # Simple activations
+    "max",
+    "min",
+    "maximum",
+    "minimum",
+    "nn.relu",
+    "nn.leaky_relu",
+    "nn.prelu",
+    "nn.dropout",
+    # Complicated activations which saturate in a narrow range
+    "sigmoid",
+    "tanh",
+    # Pooling operations
+    "nn.max_pool1d",
+    "nn.max_pool2d",
+    "nn.max_pool3d",
+    "nn.avg_pool1d",
+    "nn.avg_pool2d",
+    "nn.avg_pool3d",
+    # "nn.global_max_pool1d", # does not exist yet
+    "nn.global_max_pool2d",
+    # "nn.global_max_pool3d", # does not exist yet
+    # "nn.global_avg_pool1d", # does not exist yet
+    "nn.global_avg_pool2d",
+    # "nn.global_avg_pool3d", # does not exist yet
+    "nn.adaptive_max_pool1d",
+    "nn.adaptive_max_pool2d",
+    "nn.adaptive_max_pool3d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+]
+DEFAULT_NEVER_LIST = [
+    # In general if |f(x)| >> |x| for expected inputs then put the op here.
+    "exp",
+    "power",
+    "nn.cross_entropy",
+    "nn.cross_entropy_with_logits",
+    "nn.softmax",
+    "nn.l2_normalize",
+    # Error function doesn't seem to be able to be lowered into fp16 version in llvm.
+    # Move to follow list when it does.
+    "erf",
+]
+
+
+# Returns a decorator which registers for every given op, the function under FTVMMixedPrecisionConversionType
+def register_func_to_op_list(list_ops):
+    def decorator(func):
+        for op_name in list_ops:
+            register_mixed_precision_conversion(op_name, func=func)
+
+    return decorator
+
+
+def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]:
+    # Assume support accumulation dtypes <---> has out_dtype attr

Review comment:
       Done

##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }

Review comment:
       Yeah, so the problems fundamentally come from my not thinking about or handling algebraic data types (ADTs). I've added a warning, and the pass should fail with an appropriate error message.
   
   I'm going to defer support for ADTs to a future PR. Most models don't use these features, I believe.
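The "fail with a clear error" behavior for unsupported types can be sketched as a type predicate. It walks tensor and tuple types and raises on anything else, such as an ADT, instead of silently mis-handling it. The tensor/tuple recursion loosely follows `IsMixedPrecisionType` in the diff. `TensorType`, `TupleType`, and `ADTType` are hypothetical stand-ins, not Relay's type classes. The `floatN` dtype-name check is also an assumption.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TensorType:
    dtype: str


@dataclass(frozen=True)
class TupleType:
    fields: Tuple[object, ...]


@dataclass(frozen=True)
class ADTType:
    """Hypothetical placeholder for an algebraic data type such as a List."""

    name: str


def is_mixed_precision_type(t, target: str, ignore_non_float: bool = False) -> bool:
    """True if the tensors in t have dtype `target` (non-float ones skipped if ignore_non_float)."""
    if isinstance(t, TensorType):
        # With ignore_non_float, non-float tensors (e.g. int32 indices) are not held against t.
        # Assumes float dtypes are spelled "floatN".
        if ignore_non_float and not t.dtype.startswith("float"):
            return True
        return t.dtype == target
    if isinstance(t, TupleType):
        return all(is_mixed_precision_type(f, target, ignore_non_float) for f in t.fields)
    # Anything else (e.g. an ADT) is unsupported: fail loudly instead of guessing.
    raise TypeError(f"Unsupported type {t!r}: mixed precision conversion does not handle it yet")


pair = TupleType((TensorType("float16"), TensorType("int32")))
print(is_mixed_precision_type(pair, "float16"))                         # False
print(is_mixed_precision_type(pair, "float16", ignore_non_float=True))  # True
```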







[GitHub] [tvm] AndrewZhaoLuo edited a comment on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo edited a comment on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-863603673


   @mbrookhart PTAL. I'm going to push ADT support down to a future PR.





[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r650300650



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into fp16 form.
+ */
+#include "to_mixed_precision.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<MixedTypeConversionCategory(const CallNode*)>;
+
+// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<MixedPrecisionOpOutDType(const CallNode*)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+  const DataType mixed_precision_type;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;

Review comment:
       I reviewed the cache mechanism and I think I got the idea. Here is the example I went through:
   
   Consider the op `A (out: fp32, want: fp16)`. After processing A's output, the cache looks like this:
   ```
   (A, fp16): cast
   (cast, fp32): A
   ```
   
   Now consider the following op `B`:
   Case 1. If `B` wants fp32, then, as you mentioned before, we query `(cast, fp32)` and get `A`, so it becomes `A -> B`.
   Case 2. If `B` wants fp16, then we query `(cast, fp16)`, which is missed and a new entry `(cast, fp16): cast` is created and returned, so it becomes `A -> cast -> B`.
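The walkthrough above can be replayed with a small Python model of `CachedCast`. The model stores the forward entry plus the reverse entry that `CachedCast` records. `Expr`, `Cast`, and `CastCache` are hypothetical stand-ins hashed by object identity, in the same spirit as keying on `const ExprNode*`. This is not TVM code.

```python
from typing import Dict, Tuple


class Expr:
    """Minimal stand-in for a Relay expression; hashed by identity like an ExprNode*."""

    def __init__(self, name: str, dtype: str):
        self.name = name
        self.dtype = dtype

    def __repr__(self) -> str:
        return self.name


class Cast(Expr):
    def __init__(self, arg: Expr, dtype: str):
        super().__init__(f"cast({arg.name}->{dtype})", dtype)
        self.arg = arg


class CastCache:
    def __init__(self):
        # (source node, wanted dtype) -> node that holds the same value in that dtype
        self.cache: Dict[Tuple[Expr, str], Expr] = {}

    def cached_cast(self, expr: Expr, wanted: str) -> Expr:
        key = (expr, wanted)
        if key in self.cache:
            return self.cache[key]
        result = expr if expr.dtype == wanted else Cast(expr, wanted)
        self.cache[key] = result
        # Reverse entry: casting the result back to expr's dtype yields expr itself.
        self.cache[(result, expr.dtype)] = expr
        return result


cache = CastCache()
a = Expr("A", "float32")
a_fp16 = cache.cached_cast(a, "float16")  # A's output: (A, fp16) -> cast, (cast, fp32) -> A

# Case 1: B wants fp32, so the reverse entry hands back A and no new cast is made.
print(cache.cached_cast(a_fp16, "float32"))  # A
# Case 2: B wants fp16, so (cast, fp16) misses and the cast itself is returned and cached.
print(cache.cached_cast(a_fp16, "float16"))  # cast(A->float16)
print(len(cache.cache))  # 3
```

The third entry, `(cast, fp16) -> cast`, is the identity entry from case 2.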
   
   This mechanism seems to work well, and the cache size should be reasonable since it only keeps pointers. Two possible improvements:
   1. The cache entry `(cast, fp16): cast` in the example is clearly unnecessary. I think we can simply return `expr` when `expr_dtype == wanted_dtype`, without caching it?
   2. The created `cast` ops may be useless, such as the one in case 1. Is it possible to create this op lazily? For example, when casting the output, we only create a cache entry but don't actually create the node. When the following ops query the entry for the first time, we create the cast node and update the cache.
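Suggestions 1 and 2 together could look roughly like the sketch below. `A`'s output is passed along uncast, and a cast is only materialized the first time a consumer actually needs a different dtype. This is a hypothetical Python sketch; `LazyCastCache` and the stand-in classes are made up for illustration. It needs no reverse entries, because nothing is cast eagerly.

```python
from typing import Dict, Tuple


class Expr:
    def __init__(self, name: str, dtype: str):
        self.name = name
        self.dtype = dtype


class Cast(Expr):
    def __init__(self, arg: Expr, dtype: str):
        super().__init__(f"cast({arg.name}->{dtype})", dtype)
        self.arg = arg


class LazyCastCache:
    """Casts are built only when a consumer asks for them; identity casts are never cached."""

    def __init__(self):
        self.cache: Dict[Tuple[Expr, str], Cast] = {}
        self.casts_created = 0

    def get(self, expr: Expr, wanted: str) -> Expr:
        # Suggestion 1: same dtype means no cast and no cache entry.
        if expr.dtype == wanted:
            return expr
        key = (expr, wanted)
        # Suggestion 2: the cast node is created on the first query, then reused.
        if key not in self.cache:
            self.cache[key] = Cast(expr, wanted)
            self.casts_created += 1
        return self.cache[key]


lazy = LazyCastCache()
a = Expr("A", "float32")
b_input = lazy.get(a, "float32")  # case 1: B wants fp32 -> A itself, nothing created
c_input = lazy.get(a, "float16")  # case 2: first fp16 consumer creates the cast
d_input = lazy.get(a, "float16")  # a second fp16 consumer reuses it
print(b_input is a, c_input is d_input, lazy.casts_created)  # True True 1
```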
   
   Another direction, which I would actually recommend, is to remove the cache, let this pass generate as many cast ops as it wants, and then run the SimplifyExpr pass to cancel back-to-back cast ops (ref: https://github.com/apache/tvm/pull/8081). IIUC, this should generate the same IR as the current pass, so it shouldn't hurt final performance (please correct me if I missed something).
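For the SimplifyExpr direction, the core rewrite is a local pattern on two nested casts. Below is a hypothetical Python sketch of such a rule. It only folds `cast(cast(x, mid), out)` when the inner cast is lossless (it keeps or widens x's float precision), so dropping the inner cast cannot change values. Whether the SimplifyExpr change referenced above uses this exact condition is an assumption, not something checked here. The `floatN` dtype naming is also an assumption.

```python
class Expr:
    def __init__(self, name: str, dtype: str):
        self.name = name
        self.dtype = dtype


class Cast(Expr):
    def __init__(self, arg: Expr, dtype: str):
        super().__init__(f"cast({arg.name}->{dtype})", dtype)
        self.arg = arg


def float_bits(dtype: str) -> int:
    """Bit width of a 'floatN' dtype string, e.g. 'float16' -> 16 (naming is an assumption)."""
    if not dtype.startswith("float"):
        raise ValueError(f"not a floatN dtype: {dtype}")
    return int(dtype[len("float"):])


def fold_back_to_back_cast(expr: Expr) -> Expr:
    """Rewrite cast(cast(x, mid), out) when dropping the inner cast cannot change values."""
    if not (isinstance(expr, Cast) and isinstance(expr.arg, Cast)):
        return expr
    inner = expr.arg
    src = inner.arg
    # Only fold if the inner cast keeps or widens x's precision, i.e. is lossless.
    if float_bits(inner.dtype) < float_bits(src.dtype):
        return expr
    if expr.dtype == src.dtype:
        return src  # exact round trip, e.g. fp16 -> fp32 -> fp16
    return Cast(src, expr.dtype)


x = Expr("x", "float16")
print(fold_back_to_back_cast(Cast(Cast(x, "float32"), "float16")) is x)  # True
a = Expr("A", "float32")
narrowing = Cast(Cast(a, "float16"), "float32")
print(fold_back_to_back_cast(narrowing) is narrowing)  # True: left alone
```

This conservative condition leaves the fp32 -> fp16 -> fp32 pair from case 1 alone, because its inner cast narrows. A rule meant to reproduce the cache's behavior exactly would also need to fold that pair.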
   







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r651431025



##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;

Review comment:
       On closer thought, I will leave things as is, since only some ops would benefit from the trick I described. Exposing this through op attributes might be worthwhile in the future, but I can't think of any major savings that would come from it.
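The hunk above ends at the rule that settles the color of a GRAY op: it becomes GREEN only when every argument is already FP16-typed or is a Var/Constant that can be cast freely, and RED otherwise. Below is a minimal C++ sketch of just that rule. `ArgInfo` and `ResolveColor` are hypothetical names, and the sketch assumes that RED and GREEN ops keep their initial color, since that branch is cut off in the quoted hunk.

```cpp
#include <cassert>
#include <vector>

// Same enum as in the fp32_to_fp16.h hunk.
enum FP16ConversionCategory { RED, GRAY, GREEN };

// Hypothetical summary of one (already mutated) call argument.
struct ArgInfo {
  bool is_fp16_type;     // all floating-point tensors in the argument's type are fp16
  bool is_var_or_const;  // Vars and Constants can be cast, so their type does not matter
};

// GRAY ops follow their inputs; other colors are assumed to be kept as-is.
FP16ConversionCategory ResolveColor(FP16ConversionCategory initial_color,
                                    const std::vector<ArgInfo>& args) {
  if (initial_color != GRAY) return initial_color;
  bool all_args_fp16_compatible = true;
  for (const ArgInfo& arg : args) {
    if (!(arg.is_fp16_type || arg.is_var_or_const)) {
      all_args_fp16_compatible = false;  // one incompatible input is enough to stay in FP32
      break;
    }
  }
  return all_args_fp16_compatible ? GREEN : RED;
}
```

For example, a GRAY op whose inputs all come from GREEN ops resolves to GREEN, so no cast back to FP32 is inserted between them.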







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r649467676



##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    return h1 ^ (h2 << 1);
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    if ((mutable_attrs->out_dtype).is_float()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    if ((mutable_attrs->dtype).is_float()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}

Review comment:
       Done
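In this revision, `GetNewAttrs` starts from `Attrs(call->attrs)`, which copies only the reference, and the `ModifyAttrs*` helpers then `const_cast` the node and overwrite `out_dtype`/`dtype` in place, so the attributes of the original call change as well. The later to_mixed_precision.cc revision in this thread avoids that by copying the node with `make_object<T>(*attrs)` before editing it. The C++ sketch below shows the same copy-then-modify idea with `std::shared_ptr`; `ConvAttrs` and `WithOutDType` are hypothetical stand-ins, not TVM types, and an empty string plays the role of a void (unset) dtype.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-in for an attrs node such as Conv2DAttrs.
struct ConvAttrs {
  std::string out_dtype;  // "" stands in for an unset (void) dtype
  int groups;
};

// Returns a modified copy; the attrs object shared with the original call is untouched.
std::shared_ptr<const ConvAttrs> WithOutDType(const std::shared_ptr<const ConvAttrs>& attrs,
                                              const std::string& accumulation_dtype) {
  // Analogous to make_object<T>(*attrs): copy every field first.
  auto new_attrs = std::make_shared<ConvAttrs>(*attrs);
  bool is_float = new_attrs->out_dtype.rfind("float", 0) == 0;
  bool is_void = new_attrs->out_dtype.empty();
  // Only floating-point or unset accumulation dtypes are rewritten.
  if (is_float || is_void) new_attrs->out_dtype = accumulation_dtype;
  return new_attrs;
}
```

Copying first keeps the pass free of side effects on its input graph, at the cost of one small allocation per rewritten call.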







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648649618



##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,20 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def AMPRewrite():

Review comment:
       Hmm, I would still prefer `ToMixedPrecision`, if that is fine with you.
   
   The example you list only works for me because it lives under the `amp` namespace. On its own, outside of `torch.cuda.amp`, `AutoCast` does not convey mixed precision.







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652192089



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));
+
+        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!";
+        if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!";
+
+        // If not registered, by default assume is a generic FOLLOW operation.
+        initial_category = MIXED_PRECISION_FOLLOW;
+        accumulation_dtype = DataType::Float(16);
+        output_dtype = DataType::Float(16);

Review comment:
       Done
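`CachedCast` in this revision returns its input unchanged when it is not floating point or already has the wanted dtype. Otherwise it memoizes the new cast and also records a reverse entry, so casting the result back to the original dtype returns the original node instead of stacking a second cast. The C++ sketch below reproduces that behavior on a hypothetical `Node` struct. Its `PairHash` follows Boost's classic `hash_combine` form, `seed ^ (h + 0x9e3779b9 + (seed << 6) + (seed >> 2))`, which applies the shifts to the seed (the first hash); the `pair_hash` quoted above shifts `h2` instead, so it is a Boost-like mix rather than the exact formula.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical, heavily simplified stand-in for a Relay expression node.
struct Node {
  std::string dtype;            // e.g. "float32", "float16", "int32"
  std::shared_ptr<Node> input;  // non-null only for cast nodes
};
using NodePtr = std::shared_ptr<Node>;

// Boost's classic hash_combine: seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2).
struct PairHash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& p) const {
    std::size_t seed = std::hash<T1>()(p.first);
    std::size_t h = std::hash<T2>()(p.second);
    return seed ^ (h + 0x9e3779b9 + (seed << 6) + (seed >> 2));
  }
};

class CastCache {
 public:
  // Cast `expr` to `wanted_dtype`, reusing an earlier cast when one exists.
  NodePtr CachedCast(const NodePtr& expr, const std::string& wanted_dtype) {
    // Non-float tensors and tensors already in the wanted dtype pass through.
    if (!IsFloat(expr->dtype) || expr->dtype == wanted_dtype) return expr;
    auto search = cache_.find({expr.get(), wanted_dtype});
    if (search != cache_.end()) return search->second;
    NodePtr result = std::make_shared<Node>(Node{wanted_dtype, expr});
    cache_[{expr.get(), wanted_dtype}] = result;
    // Reverse entry: casting the new node back yields the original node.
    cache_[{result.get(), expr->dtype}] = expr;
    return result;
  }

 private:
  static bool IsFloat(const std::string& dtype) { return dtype.rfind("float", 0) == 0; }

  std::unordered_map<std::pair<const Node*, std::string>, NodePtr, PairHash> cache_;
};
```

Keying on raw node pointers, as the pass does with `const ExprNode*`, is only safe while the keyed nodes stay alive; in the sketch, the cached values keep them reachable.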







[GitHub] [tvm] AndrewZhaoLuo commented on pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-857306109


   > Can we get a few more initial reviews - @mbrookhart , @csullivan?
   > 
   > @AndrewZhaoLuo I would also suggest testing a dynamic model like SSD or Mask-RCNN. Your current list of object detection models involves Yolo, which is a static model.
   
   I tried it on an SSD model and it seems to work fine. For Mask-RCNN, I haven't found a spare model file that converts well and runs normally in FP32. 





[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648541663



##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc

Review comment:
       I would suggest naming these `amp.cc`/`amp.h` directly, as the pass should not be limited to FP32-to-FP16 conversion.

##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };
+
+using OpStringSet = std::unordered_set<std::string>;
+
+// Default lists inspired from TF's classifications:

Review comment:
       I'd prefer not to hard-code op lists in a pass, since that means we need to update this pass every time we add a new op. It would be better to follow the logic of other similar passes: register an attribute on each op, and fall back to a default behavior when an op doesn't have this attribute registered. In addition, this implementation cannot accept user-defined rules from Python.
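
   Here is a minimal Python sketch of the registry-plus-default idea, for illustration only. `register_fp16_rule`, `color_of` and the RED fallback for unregistered ops are hypothetical choices, not an existing TVM API. Each op carries its own rule, and the pass only looks rules up. A user can override a rule from Python without editing the pass.

```python
from typing import Callable, Dict

# Category names mirror FP16ConversionCategory in the diff.
GREEN, GRAY, RED = "GREEN", "GRAY", "RED"

Rule = Callable[[dict], str]

# Hypothetical stand-in for a per-op attribute table (not an existing TVM API).
_RULES: Dict[str, Rule] = {}


def register_fp16_rule(op_name: str, rule: Rule, override: bool = False) -> None:
    """Attach a conversion rule to an op, the way an op attribute would be registered."""
    if op_name in _RULES and not override:
        raise ValueError(f"a rule for {op_name!r} is already registered")
    _RULES[op_name] = rule


def default_rule(attrs: dict) -> str:
    # Assumed fallback for ops with no registered rule: keep them in FP32.
    return RED


def color_of(op_name: str, attrs: dict) -> str:
    """Return the op's initial color, falling back to the default rule."""
    return _RULES.get(op_name, default_rule)(attrs)


# Built-in rules would be registered next to each op's definition...
register_fp16_rule("nn.conv2d", lambda attrs: GREEN)
register_fp16_rule("add", lambda attrs: GRAY)

# ...and a user can override one from Python without editing the pass
# (the grouped-conv condition is purely illustrative).
register_fp16_rule(
    "nn.conv2d",
    lambda attrs: RED if attrs.get("groups", 1) > 1 else GREEN,
    override=True,
)
```

   With this design, adding a new op never touches the pass: it either registers a rule or silently gets the default.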

##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };

Review comment:
       Some suggestions:
   1. These are more like attributes than categories.
   2. Use straightforward terms, such as ALWAYS/FOLLOW/NEVER, instead of GREEN/GRAY/RED.
   3. Emitting a warning for every unspecified op may result in tedious messages. We could make this configurable so users can decide whether to print these ops.
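
   Points 2 and 3 could look roughly like this in Python. The `MixedPrecisionMode` enum and the `classify_ops` helper are hypothetical names, not TVM code. Unknown ops fall back to an assumed NEVER default. A single summary warning is logged only when `warn_unknown=True`, and the set of unknown ops is returned so callers can decide what to do with it.

```python
import logging
from enum import Enum
from typing import Dict, Iterable, Set, Tuple


class MixedPrecisionMode(Enum):
    ALWAYS = "always"  # was GREEN: always run in reduced precision
    FOLLOW = "follow"  # was GRAY: reduced precision only if the inputs already are
    NEVER = "never"    # was RED: keep full precision for numerical safety


logger = logging.getLogger("amp_sketch")


def classify_ops(
    op_names: Iterable[str],
    table: Dict[str, MixedPrecisionMode],
    warn_unknown: bool = False,
) -> Tuple[Dict[str, MixedPrecisionMode], Set[str]]:
    """Map op names to modes; unknown ops get an assumed NEVER default."""
    modes: Dict[str, MixedPrecisionMode] = {}
    unknown: Set[str] = set()
    for name in op_names:
        if name not in table:
            unknown.add(name)
        modes[name] = table.get(name, MixedPrecisionMode.NEVER)
    if warn_unknown and unknown:
        # One deduplicated summary line instead of a warning per call site.
        logger.warning("no mixed-precision mode for: %s", ", ".join(sorted(unknown)))
    return modes, unknown
```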

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;
+    } else {
+      final_color = initial_color;
+    }
+
+    // Create the new arguments to the call.
+    DataType wanted_arg_dtypes = final_color == GREEN ? DataType::Float(16) : DataType::Float(32);
+    auto call_args_and_types = CastAllArgs(new_args, new_arg_types, wanted_arg_dtypes);
+
+    Array<Expr> call_args = call_args_and_types.first;
+    Array<Type> call_arg_types;
+
+    if (call_node->op.as<FunctionNode>()) {
+      // Function Nodes don't store type info in the Call, it should be a []
+      call_arg_types = call_node->type_args;
+    } else {
+      call_arg_types = call_args_and_types.second;
+    }
+
+    // Finally create the new attributes.
+    if (final_color == GREEN) {
+      FP16OpDType output_dtypes = output_dtype_func(call_node);
+
+      Attrs new_attrs = GetNewAttrs(call_node, output_dtypes.accumulation_dtype);
+      Expr output = Call(new_op, call_args, new_attrs, call_arg_types, call_node->span);
+      if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) {
+        output = CastArg(output, GetType(output), output_dtypes.output_dtype);

Review comment:
       Wouldn't this introduce unnecessary cast ops? For example, if the accumulation dtype is FP32 and the following op is RED, will this produce `A(GREEN) - cast_to_fp16 - cast_to_fp32 - B(RED)`?
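
   The diff's `CachedCast` already writes a reverse entry ("Reverse the cache result..."). Whether the round trip appears therefore seems to depend on B's argument being cast through that same cache.

   The Python model below isolates that mechanism. `Node` and `CastCache` are hypothetical names, not TVM classes. Note that the folded path hands B the unrounded fp32 accumulator, which differs slightly from a real fp16 -> fp32 round trip. That is arguably what a RED consumer wants.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(eq=False)  # eq=False keeps identity hashing, like keying on ExprNode*
class Node:
    op: str
    dtype: str
    args: Tuple["Node", ...] = ()


class CastCache:
    """Memoize casts per (node, dtype) and record how to undo each one."""

    def __init__(self) -> None:
        self._cache: Dict[Tuple[Node, str], Node] = {}

    def cast(self, node: Node, dtype: str) -> Node:
        if node.dtype == dtype:
            return node
        hit = self._cache.get((node, dtype))
        if hit is not None:
            return hit
        result = Node("cast", dtype, (node,))
        self._cache[(node, dtype)] = result
        # Reverse entry: casting `result` back to the source dtype returns the
        # original node instead of stacking a second cast on top.
        self._cache[(result, node.dtype)] = node
        return result


# A(GREEN) accumulates in fp32; its output is cast to fp16 for downstream users.
acc = Node("nn.conv2d", "float32")
cache = CastCache()
to_fp16 = cache.cast(acc, "float16")
# B(RED) wants fp32: the reverse entry hands back `acc`, so no fp16 -> fp32 cast is added.
for_red = cache.cast(to_fp16, "float32")
```

   If B's argument is rebuilt or cast outside this cache, the reverse entry is never consulted and the double cast would remain.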

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }

Review comment:
       You don't need to run this whole compatibility check when `initial_color` is not GRAY, since the final color only depends on the arguments for GRAY ops.
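
   One way to read this suggestion, sketched in Python with hypothetical names: decide the color up front, and inspect arguments only for GRAY ops. `all()` then stops at the first incompatible argument. In the C++ loop, every argument still has to be mutated; only the FP16-compatibility bookkeeping can be skipped.

```python
from typing import Callable, Iterable

GREEN, GRAY, RED = "GREEN", "GRAY", "RED"


def final_color(
    initial_color: str,
    arg_dtypes: Iterable[str],
    is_fp16_compatible: Callable[[str], bool],
) -> str:
    """Resolve an op's color, inspecting arguments only when it is GRAY."""
    if initial_color != GRAY:
        return initial_color  # GREEN and RED do not depend on the arguments
    # all() short-circuits at the first incompatible argument.
    return GREEN if all(is_fp16_compatible(d) for d in arg_dtypes) else RED
```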

##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,20 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def AMPRewrite():

Review comment:
       Since this is a user-facing API, we might want to think a bit more about making the name straightforward. For example, `AMP` or `AutoCast` would be better names IMHO.

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+class AmpGraphCreator : public ExprMutator {

Review comment:
       Please use an iterative implementation instead of recursion to avoid stack overflow on deep graphs.
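
   In C++, Relay passes typically get this behavior from `MixedModeMutator` (declared alongside `ExprMutator`) rather than by writing the stack by hand. The Python below only illustrates the explicit-stack, post-order idea. `post_order_rewrite` and `Node` are hypothetical names.

```python
from typing import Callable, Dict, List, Sequence, TypeVar

T = TypeVar("T")


def post_order_rewrite(
    root: object,
    children: Callable[[object], Sequence[object]],
    rewrite: Callable[[object, List[T]], T],
) -> T:
    """Rewrite a DAG bottom-up using an explicit stack instead of recursion.

    Every node is rewritten exactly once (memoized by identity), after all of
    its children, so shared sub-expressions are not duplicated. Nodes must stay
    alive during the traversal, which holds when they are reachable from root.
    """
    memo: Dict[int, T] = {}
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if id(node) in memo:
            continue
        if children_done:
            memo[id(node)] = rewrite(node, [memo[id(c)] for c in children(node)])
        else:
            # Revisit this node once its children have been rewritten.
            stack.append((node, True))
            for child in reversed(children(node)):
                if id(child) not in memo:
                    stack.append((child, False))
    return memo[id(root)]


class Node:
    def __init__(self, op: str, *args: "Node") -> None:
        self.op = op
        self.args = list(args)
```

   A chain tens of thousands of nodes deep is processed without touching Python's recursion limit, which a recursive visitor would exceed.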

##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+// Default lists inspired from TF's classifications:
+// github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+// They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
+OpStringSet DEFAULT_GREEN_LIST({
+    "nn.conv1d",
+    "nn.conv2d",
+    "nn.conv3d",
+    "nn.conv1d_transpose",
+    "nn.conv2d_transpose",
+    "nn.conv3d_transpose",
+    "nn.dense",
+    "nn.batch_matmul",
+});
+OpStringSet DEFAULT_GRAY_LIST({
+    // These ops add new data or change shape
+    "nn.pad",
+    "nn.batch_flatten",
+    "concatenate",
+    "zeros",
+    "split",
+    "squeeze",
+    "transpose",
+    "expand_dims",
+    "reshape",
+    "dyn.reshape",
+    "broadcast_to_like",
+    "dyn.broadcast_to",
+    "strided_slice",
+    "dyn.strided_slice",
+    "take",
+    "argwhere",
+    "where",
+    "tile",
+    "dyn.tile",
+    "scatter",
+    "full",
+    "dyn.full",
+    // Comparison
+    "less",
+    "greater",
+    "less_equal",
+    "greater_equal",
+    // By definition copy and cast will become green or red based on inputs
+    "copy",
+    "cast",
+    "cast_like",
+    // Simple arithmetic
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "nn.bias_add",
+    "nn.batch_norm",
+    "sum",
+    "mean",
+    "sqrt",
+    "shape_of",
+    // Simple activations
+    "max",
+    "min",
+    "maximum",
+    "minimum",
+    "nn.relu",
+    "nn.leaky_relu",
+    "nn.prelu",
+    "nn.dropout",
+    // Complicated activations which saturate in a narrow range
+    "sigmoid",
+    "tanh",
+    // Pooling operations
+    "nn.max_pool1d",
+    "nn.max_pool2d",
+    "nn.max_pool3d",
+    "nn.avg_pool1d",
+    "nn.avg_pool2d",
+    "nn.avg_pool3d",
+    // "nn.global_max_pool1d", // does not exist yet
+    "nn.global_max_pool2d",
+    // "nn.global_max_pool3d", // does not exist yet
+    // "nn.global_avg_pool1d", // does not exist yet
+    "nn.global_avg_pool2d",
+    // "nn.global_avg_pool3d", // does not exist yet
+    "nn.adaptive_max_pool1d",
+    "nn.adaptive_max_pool2d",
+    "nn.adaptive_max_pool3d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+});
+OpStringSet DEFAULT_RED_LIST({
+    // In general if |f(x)| >> |x| for expected inputs then put the op here.
+    "exp",
+    "power",
+    "nn.cross_entropy",
+    "nn.cross_entropy_with_logits",
+    "nn.softmax",
+    "nn.l2_normalize",
+    // The error function does not currently appear to lower to an fp16 version in LLVM.
+    // Move it to the gray list once it does.
+    "erf",
+});
+
+class DefaultFP16Colorer {
+  /* The default class to initially color ops for conversion using lists.
+
+  Creates a callable which given a CallNode* returns the node's color.
+  */

Review comment:
       Please use the formal docstring format.
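The coloring scheme these lists feed can be modeled in a few lines of plain Python. This is a simplified sketch, not TVM's API: `Color`, `initial_color`, and `resolve_color` are hypothetical names, the lists are abbreviated, and treating unlisted ops as red is an assumption, because this excerpt does not show what `DefaultFP16Colorer` returns for ops missing from all three lists. The gray rule mirrors the `VisitExpr_` logic in the pass further down: a gray op becomes green only if every floating-point argument is already fp16 or is a var/constant that can be cast freely. Non-float arguments are skipped, which is what the `ignore_non_float` comment on `IsFP16Type` describes.

```python
from enum import Enum


class Color(Enum):
    GREEN = "green"  # always run in fp16
    GRAY = "gray"    # follow the dtypes of the inputs
    RED = "red"      # always keep fp32


# Abbreviated, hypothetical subsets of the default lists.
GREEN_LIST = {"nn.conv2d", "nn.dense", "nn.batch_matmul"}
GRAY_LIST = {"add", "nn.relu", "nn.batch_norm", "where", "cast"}
RED_LIST = {"exp", "nn.softmax", "erf"}


def initial_color(op_name):
    """Initial color from the lists; unlisted ops default to RED (assumption)."""
    if op_name in GREEN_LIST:
        return Color.GREEN
    if op_name in GRAY_LIST:
        return Color.GRAY
    return Color.RED


def resolve_color(op_name, args):
    """Resolve GRAY ops from their arguments.

    Each arg is a (dtype, is_var_or_constant) pair. Non-float args are ignored,
    and vars/constants never block fp16 because they can always be cast.
    """
    color = initial_color(op_name)
    if color is not Color.GRAY:
        return color
    compatible = all(
        is_var_or_const or not dtype.startswith("float") or dtype == "float16"
        for dtype, is_var_or_const in args
    )
    return Color.GREEN if compatible else Color.RED
```

For example, an `add` fed by a conv's fp16 output resolves to green, while the same `add` fed by a softmax's fp32 output resolves to red, matching `test_green_gray_propagates_simple` and `test_red_gray_propagates_simple` below.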

##########
File path: tests/python/relay/test_fp32_to_fp16_transform.py
##########
@@ -0,0 +1,328 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for testing FP32 -> FP16 pass"""
+from typing import Any, Dict, List
+
+import numpy as np
+import tvm
+from tvm import relay
+from tvm.relay.testing import lstm
+from tvm.relay.transform import AMPRewrite
+from tvm.relay.transform.transform import InferType
+
+
+def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List:
+    dev = tvm.device("llvm", 0)
+    intrp = relay.create_executor("debug", mod, device=dev, target="llvm")
+    result = intrp.evaluate()(**mod_params)
+    if isinstance(result, tvm.runtime.container.ADT):
+        result = [r.asnumpy() for r in result]
+        return result
+    else:
+        return [result.asnumpy()]
+
+
+def verify_fp32_fp16_output_close(
+    mod: tvm.runtime.Module, mod_params: Dict[str, Any], rtol: float = 1e-3, atol: float = 0
+) -> tvm.runtime.Module:
+    mod = InferType()(mod)
+    result_fp32 = run_module(mod, mod_params)
+    fp16_mod = AMPRewrite()(mod)
+    result_fp16 = run_module(fp16_mod, mod_params)
+
+    # Ensure the results are close
+    for fp32, fp16 in zip(result_fp32, result_fp16):
+        np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
+
+    return fp16_mod
+
+
+def test_lstm():
+    """A small stress test on a single unrolled lstm unit.
+
+    Has internal functions and let statements the pass must work on.
+    """
+    np.random.seed(5628)
+    units = 3
+    iterations = 5
+    mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units)
+
+    # This is an unrolled lstm so each data should be the previous results but
+    # we don't care, we just want to stress test things.
+    for i in range(iterations):
+        mod_params["data" if i == 0 else f"data{i}"] = np.random.uniform(
+            -10, 10, (1, units)
+        ).astype("float32")
+
+    verify_fp32_fp16_output_close(mod, mod_params, rtol=0.01, atol=0.01)
+
+
+def test_convert_single_conv():
+    """Conv is a green listed operation meaning it will always use fp16 workload.
+
+    By default it accumulates to fp32 and outputs fp16.
+    """
+    np.random.seed(208)
+
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+    mod = tvm.IRModule.from_expr(conv)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+    }
+    fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+
+    expected_mod = tvm.IRModule.from_expr(
+        relay.cast(
+            relay.nn.conv2d(
+                relay.cast(data, "float16"),
+                relay.cast(weight, "float16"),
+                strides=(1, 1),
+                padding=(1, 1),
+                out_dtype="float32",
+            ),
+            "float16",
+        )
+    )
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+
+    assert not tvm.ir.structural_equal(fp16_mod, mod)
+    assert tvm.ir.structural_equal(fp16_mod, expected_mod)
+
+
+def test_convert_conv_bn():
+    """Conv is green and batch norm is gray. As Conv outputs fp16, batch_norm should be green."""
+    np.random.seed(208)
+
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+
+    bn_shape = [5]
+    gamma = relay.var("gamma", shape=bn_shape)
+    beta = relay.var("beta", shape=bn_shape)
+    moving_mean = relay.var("moving_mean", shape=bn_shape)
+    moving_var = relay.var("moving_var", shape=bn_shape)
+    bn = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var)
+    mod = tvm.IRModule.from_expr(bn[0])
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+        "gamma": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+        "beta": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+        "moving_mean": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+        "moving_var": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+    }
+    fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+
+    # Creating expected module
+    data = relay.cast(relay.var("data", shape=data_shape), "float16")
+    weight = relay.cast(relay.var("weight", shape=weight_shape), "float16")
+    conv = relay.cast(
+        relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32"),
+        "float16",
+    )
+
+    bn_shape = [5]
+    gamma = relay.cast(relay.var("gamma", shape=bn_shape), "float16")
+    beta = relay.cast(relay.var("beta", shape=bn_shape), "float16")
+    moving_mean = relay.cast(relay.var("moving_mean", shape=bn_shape), "float16")
+    moving_var = relay.cast(relay.var("moving_var", shape=bn_shape), "float16")
+    bn = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var)
+
+    expected_mod = tvm.IRModule.from_expr(bn[0])
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+    assert not tvm.ir.structural_equal(fp16_mod, mod)
+    assert tvm.ir.structural_equal(fp16_mod, expected_mod)
+
+
+def test_do_not_convert_softmax():
+    """Softmax is a red listed operation and therefore should never be fp16."""
+    np.random.seed(209)
+    shape = [1, 2, 3]
+    a = relay.var("a", shape=shape)
+    b = relay.nn.softmax(a)
+    mod = tvm.IRModule.from_expr(b)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "a": np.random.uniform(-1, 1, size=shape).astype("float32"),
+    }
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.0, rtol=0)
+    assert tvm.ir.structural_equal(mod, output_mod)
+
+
+def test_green_gray_propagates_simple():
+    """Conv is a green listed operation, while addition is gray.
+
+    As Conv outputs fp16 the add should be done in fp16.
+    """
+    np.random.seed(210)
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+    conv = conv + conv
+    mod = tvm.IRModule.from_expr(conv)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+    }
+    fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+
+    conv_expr = relay.cast(
+        relay.nn.conv2d(
+            relay.cast(data, "float16"),
+            relay.cast(weight, "float16"),
+            strides=(1, 1),
+            padding=(1, 1),
+            out_dtype="float32",
+        ),
+        "float16",
+    )
+    expected_mod = tvm.IRModule.from_expr(conv_expr + conv_expr)
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+
+    assert not tvm.ir.structural_equal(fp16_mod, mod)
+    assert tvm.ir.structural_equal(fp16_mod, expected_mod)
+
+
+def test_red_gray_propagates_simple():
+    """Everything after a softmax should be in FP32 (except green-colored ops)."""
+    np.random.seed(211)
+    shape = [1, 2, 3]
+    a = relay.var("a", shape=shape)
+    b = relay.nn.softmax(a)
+    c = b + b
+    mod = tvm.IRModule.from_expr(c)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "a": np.random.uniform(-1, 1, size=shape).astype("float32"),
+    }
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.0, rtol=0.0)
+
+    assert tvm.ir.structural_equal(mod, output_mod)
+
+
+def test_let_statement_simple():
+    """A 'simple' let statement example.
+
+    Noticeable is the mutation of the bound variable types.
+    """
+    np.random.seed(211)
+    var1 = relay.var("var1", shape=[1, 20])
+    var2 = relay.var("var2", shape=[1, 20])
+
+    data = relay.var("data", shape=[1, 20])
+    weight = relay.var("weight", shape=[20, 20])
+
+    r1 = var1 + var1
+
+    r2 = var2 + var2
+    let2 = relay.Let(var2, relay.nn.dense(r1, weight, units=20), r2)
+    let1 = relay.Let(var1, relay.nn.dense(data, weight, units=20), let2)
+
+    mod = tvm.IRModule.from_expr(let1)
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"),
+    }
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+
+    # Construct expected structure
+    var1 = relay.var("var1", shape=[1, 20], dtype="float16")
+    var2 = relay.var("var2", shape=[1, 20], dtype="float16")
+    data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
+    weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
+    r1 = var1 + var1
+    r2 = var2 + var2
+    let2 = relay.Let(
+        var2,
+        relay.cast(relay.nn.dense(r1, weight, units=20, out_dtype="float32"), "float16"),
+        r2,
+    )
+    let1 = relay.Let(
+        var1,
+        relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16"),
+        let2,
+    )
+    expected_mod = tvm.IRModule.from_expr(let1)
+    expected_mod = InferType()(expected_mod)
+
+    assert tvm.ir.structural_equal(expected_mod, output_mod)
+
+
+def test_where_simple():
+    data = relay.var("data", shape=[1, 20])
+    weight = relay.var("weight", shape=[20, 20])
+    a = relay.nn.dense(data, weight, units=20)
+    b = relay.where(data, a, a)
+    mod = tvm.IRModule.from_expr(b)
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"),
+    }
+
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+
+    # Create expected module
+    data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
+    weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
+    a = relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16")
+    b = relay.where(data, a, a)
+    expected_mod = tvm.IRModule.from_expr(b)
+    expected_mod = InferType()(expected_mod)
+
+    assert tvm.ir.structural_equal(expected_mod, output_mod)
+
+
+def test_batch_matmul_simple():
+    """Batch matmul is a special case where we try to accumulate to fp16.
+
+    This is because heterogeneous accumulation dtypes do not work
+    on all platforms at the moment.
+    """
+    data = relay.var("data", shape=[1, 1, 20])
+    weight = relay.var("weight", shape=[1, 20, 20])
+    a = relay.nn.batch_matmul(data, weight)
+    mod = tvm.IRModule.from_expr(a)
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 1, 20]).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=[1, 20, 20]).astype("float32"),
+    }
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+    # Create expected module
+    data = relay.cast(relay.var("data", shape=[1, 1, 20]), "float16")
+    weight = relay.cast(relay.var("weight", shape=[1, 20, 20]), "float16")
+    a = relay.nn.batch_matmul(data, weight, out_dtype="float16")
+    expected_mod = tvm.IRModule.from_expr(a)
+    expected_mod = InferType()(expected_mod)
+    assert tvm.ir.structural_equal(expected_mod, output_mod)

Review comment:
       Add the following to the end.
   
   ```python
   if __name__ == "__main__":
       pytest.main([__file__])
   ```
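The tests above assert that green ops accumulate in fp32 by default and compare outputs with tolerances around `1e-3`. A small NumPy experiment, illustrative only and not part of the PR, shows why both choices are reasonable: fp16 has a 10-bit stored mantissa, so its epsilon is 2**-10 (about 9.8e-4), and an fp16 accumulator silently drops small increments once the running sum grows. The helper `dot_fp16` is hypothetical; it forms each product in fp16 and adds it in the chosen accumulator dtype, which loosely mimics `out_dtype` on a green op.

```python
import numpy as np

# fp16 stores 10 mantissa bits, so its machine epsilon is 2**-10 (~9.77e-4).
assert np.finfo(np.float16).eps == 2.0 ** -10


def dot_fp16(a, b, acc_dtype):
    """Dot product of fp16 vectors, accumulating in `acc_dtype`."""
    acc = acc_dtype(0)
    for x, y in zip(a, b):
        # The product is formed in fp16, then added in the accumulator dtype.
        acc = acc_dtype(acc + acc_dtype(x * y))
    return acc


n = 4096
a = np.full(n, 0.1, dtype=np.float16)
b = np.ones(n, dtype=np.float16)
exact = float(np.float16(0.1)) * n  # exact sum of the fp16 products

acc16 = dot_fp16(a, b, np.float16)
acc32 = dot_fp16(a, b, np.float32)
print(f"fp16 acc: {float(acc16)}, fp32 acc: {float(acc32)}, exact: {exact}")
```

Above 256 the spacing between fp16 values is 0.25, so adding roughly 0.1 rounds back to the same value and the fp16 accumulator stops growing, while the fp32 accumulator reaches the exact sum. This is the cost the `batch_matmul` special case accepts when it accumulates in fp16.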

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;
+    } else {
+      final_color = initial_color;
+    }
+
+    // Create the new arguments to the call.
+    DataType wanted_arg_dtypes = final_color == GREEN ? DataType::Float(16) : DataType::Float(32);
+    auto call_args_and_types = CastAllArgs(new_args, new_arg_types, wanted_arg_dtypes);
+
+    Array<Expr> call_args = call_args_and_types.first;
+    Array<Type> call_arg_types;
+
+    if (call_node->op.as<FunctionNode>()) {
+      // Function Nodes don't store type info in the Call, it should be a []
+      call_arg_types = call_node->type_args;
+    } else {
+      call_arg_types = call_args_and_types.second;
+    }
+
+    // Finally create the new attributes.
+    if (final_color == GREEN) {
+      FP16OpDType output_dtypes = output_dtype_func(call_node);
+
+      Attrs new_attrs = GetNewAttrs(call_node, output_dtypes.accumulation_dtype);
+      Expr output = Call(new_op, call_args, new_attrs, call_arg_types, call_node->span);
+      if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) {
+        output = CastArg(output, GetType(output), output_dtypes.output_dtype);
+      }
+      return output;
+    } else {
+      return Call(new_op, call_args, call_node->attrs, call_arg_types, call_node->span);
+    }

Review comment:
       ```suggestion
       }
       return Call(new_op, call_args, call_node->attrs, call_arg_types, call_node->span);
   ```
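The `CachedCast` helper above keeps a bidirectional cache: after casting `x` to fp16 it also records that casting the result back to the original dtype should return `x` itself, so a fp32 -> fp16 -> fp32 round trip adds no extra cast nodes. A toy Python model of that idea follows; `CastCache` and the tuple encoding of a cast are hypothetical. Note that the C++ keys the cache by `ExprNode*` pointer identity, whereas this sketch uses value equality of hashable keys for simplicity.

```python
class CastCache:
    """Toy model of a bidirectional cast cache (hypothetical, not TVM's API).

    Nodes are hashable keys; a cast is represented as ("cast", node, dtype).
    """

    def __init__(self):
        self._cache = {}  # (node, wanted_dtype) -> node that has wanted_dtype
        self.casts_created = 0

    def cached_cast(self, node, node_dtype, wanted_dtype):
        # Non-float tensors (e.g. int32 indices) are never cast.
        if not node_dtype.startswith("float"):
            return node
        hit = self._cache.get((node, wanted_dtype))
        if hit is not None:
            return hit
        if node_dtype == wanted_dtype:
            result = node
        else:
            result = ("cast", node, wanted_dtype)
            self.casts_created += 1
        self._cache[(node, wanted_dtype)] = result
        # Reverse entry: casting the result back yields the original node.
        self._cache[(result, node_dtype)] = node
        return result
```

Repeated requests for the same cast return the cached node, and casting the fp16 result back to fp32 returns the original fp32 node instead of building a second cast.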

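The `pair_hash` functor in this file describes itself as boost's combine strategy. For comparison, boost's classic `hash_combine` is `seed ^= hash(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2)`, i.e. the seed is the value that gets shifted. The expression in the diff swaps the roles (it adds `h1` to the constant and shifts `h2`); both still mix the two hashes, but only the former matches the comment literally. The sketch below models the boost form with explicit 64-bit wraparound, assuming a 64-bit `std::size_t`; the function names are illustrative.

```python
MASK64 = (1 << 64) - 1  # emulate a 64-bit std::size_t


def hash_combine(seed: int, value_hash: int) -> int:
    """Boost-style hash_combine with 64-bit wraparound."""
    mixed = (value_hash + 0x9E3779B9 + ((seed << 6) & MASK64) + (seed >> 2)) & MASK64
    return (seed ^ mixed) & MASK64


def pair_hash(h1: int, h2: int) -> int:
    """Order-sensitive hash of a pair, given the hashes of its elements."""
    return hash_combine(h1, h2)
```

In C++ terms, matching the comment would mean `h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2))`; unsigned overflow wraps there automatically, which the mask emulates here.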
##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;

Review comment:
       Can you provide an example of an argument that is FP16 incompatible?
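To illustrate the rule in `VisitExpr_` above (a sketch, not the PR's code): a GRAY op is promoted to GREEN only when every argument already has an fp16 type or is a `Var`/`Constant`, which the pass can cast freely. An argument produced by a RED op, such as the fp32 output of `nn.softmax`, is FP16 incompatible and forces the GRAY op to RED. `ArgInfo` and `ResolveColor` are hypothetical names that model the logic without depending on TVM.

```cpp
#include <cassert>
#include <vector>

// Same three categories as fp32_to_fp16.h.
enum FP16ConversionCategory { RED, GRAY, GREEN };

// Hypothetical summary of one mutated call argument.
struct ArgInfo {
  bool is_fp16;             // what IsFP16Type(new_arg_type, true) reports
  bool is_var_or_constant;  // Vars and Constants can always be cast
};

// RED and GREEN are kept as-is; a GRAY op follows its arguments.
FP16ConversionCategory ResolveColor(FP16ConversionCategory initial,
                                    const std::vector<ArgInfo>& args) {
  if (initial != GRAY) return initial;
  for (const ArgInfo& arg : args) {
    // A single incompatible argument (e.g. an fp32 softmax output) is enough.
    if (!arg.is_fp16 && !arg.is_var_or_constant) return RED;
  }
  return GREEN;
}
```

For instance, `ResolveColor(GRAY, {{false, false}})` returns `RED`, which corresponds to a GRAY op whose only input comes from a RED op.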

##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };
+
+using OpStringSet = std::unordered_set<std::string>;
+
+// Default lists inspired from TF's classifications:
+// github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+// They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
+OpStringSet DEFAULT_GREEN_LIST({
+    "nn.conv1d",
+    "nn.conv2d",
+    "nn.conv3d",
+    "nn.conv1d_transpose",
+    "nn.conv2d_transpose",
+    "nn.conv3d_transpose",
+    "nn.dense",
+    "nn.batch_matmul",
+});
+OpStringSet DEFAULT_GRAY_LIST({
+    // These ops add new data or change shape
+    "nn.pad",
+    "nn.batch_flatten",
+    "concatenate",
+    "zeros",
+    "split",
+    "squeeze",
+    "transpose",
+    "expand_dims",
+    "reshape",
+    "dyn.reshape",
+    "broadcast_to_like",
+    "dyn.broadcast_to",
+    "strided_slice",
+    "dyn.strided_slice",
+    "take",
+    "argwhere",
+    "where",
+    "tile",
+    "dyn.tile",
+    "scatter",
+    "full",
+    "dyn.full",
+    // Comparison
+    "less",
+    "greater",
+    "less_equal",
+    "greater_equal",
+    // By definition copy and cast will become green or red based on inputs
+    "copy",
+    "cast",
+    "cast_like",
+    // Simple arithmetic
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "nn.bias_add",
+    "nn.batch_norm",
+    "sum",
+    "mean",
+    "sqrt",
+    "shape_of",
+    // Simple activations
+    "max",
+    "min",
+    "maximum",
+    "minimum",
+    "nn.relu",
+    "nn.leaky_relu",
+    "nn.prelu",
+    "nn.dropout",
+    // Complicated activations which saturate in a narrow range
+    "sigmoid",
+    "tanh",
+    // Pooling operations
+    "nn.max_pool1d",
+    "nn.max_pool2d",
+    "nn.max_pool3d",
+    "nn.avg_pool1d",
+    "nn.avg_pool2d",
+    "nn.avg_pool3d",
+    // "nn.global_max_pool1d", // does not exist yet
+    "nn.global_max_pool2d",
+    // "nn.global_max_pool3d", // does not exist yet
+    // "nn.global_avg_pool1d", // does not exist yet
+    "nn.global_avg_pool2d",
+    // "nn.global_avg_pool3d", // does not exist yet
+    "nn.adaptive_max_pool1d",
+    "nn.adaptive_max_pool2d",
+    "nn.adaptive_max_pool3d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+});
+OpStringSet DEFAULT_RED_LIST({
+    // In general if |f(x)| >> |x| for expected inputs then put the op here.
+    "exp",
+    "power",
+    "nn.cross_entropy",
+    "nn.cross_entropy_with_logits",
+    "nn.softmax",
+    "nn.l2_normalize",
+    // Error function doesn't seem to be able to be lowered into fp16 version in llvm.
+    // Move to gray list when it does.
+    "erf",
+});
+
+class DefaultFP16Colorer {
+  /* The default class to initially color ops for conversion using lists.
+
+  Creates a callable which given a CallNode* returns the node's color.
+  */
+ private:
+  std::unordered_map<std::string, FP16ConversionCategory> op_to_initial_color;
+
+ public:
+  DefaultFP16Colorer(OpStringSet red_list = DEFAULT_RED_LIST,
+                     OpStringSet gray_list = DEFAULT_GRAY_LIST,
+                     OpStringSet green_list = DEFAULT_GREEN_LIST) {
+    std::vector<std::pair<OpStringSet, FP16ConversionCategory>> lists_and_colors{
+        {red_list, RED}, {gray_list, GRAY}, {green_list, GREEN}};
+
+    for (auto list_and_color : lists_and_colors) {
+      OpStringSet ops = list_and_color.first;
+      FP16ConversionCategory color = list_and_color.second;
+      for (std::string op_name : ops) {
+        op_to_initial_color.insert({{op_name, color}});
+      }
+    }
+  }
+
+  FP16ConversionCategory operator()(const CallNode* call, bool ignore_missing = true) {
+    if (auto* op_node = (call->op).as<tvm::OpNode>()) {
+      std::string op_name = op_node->name;
+      auto color = op_to_initial_color.find(op_name);
+
+      if (color == op_to_initial_color.end()) {
+        if (ignore_missing) {
+          LOG(WARNING) << "Op name " << op_name << " not in included in fp16 conversion lists!";
+          return RED;
+        } else {
+          LOG(FATAL) << "Op name " << op_name << " not in included in fp16 lists!";
+        }
+      }
+
+      return color->second;
+    } else if ((call->op).as<FunctionNode>()) {
+      // Make RED to avoid messing with function headers.
+      return RED;
+    } else {
+      LOG(FATAL) << "FP16 conversion only supports call nodes with OpNodes or Functions got "
+                 << call->op;
+      return RED;
+    }
+  }
+};
+
+class DefaultFP16OpDefinition {
+  /* The default callable for determining accumulation_dtypes for ops. */
+ public:
+  FP16OpDType operator()(const CallNode* call) {
+    // TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well.
+    // Batched matmul has inconsistent support for mixed precision operations.
+    // Many schedules ignore the out_dtype attribute which leads to errors when
+    // input types do not match the out_dtype. Therefore, accumulate to fp16 if green.
+    if (auto op_node = call->op.as<OpNode>()) {
+      if (op_node->name == "nn.batch_matmul") {
+        return {DataType::Float(16), DataType::Float(16)};
+      }
+    }

Review comment:
       This again illustrates the importance of registering a casting function for each op instead of hardcoding it here.
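One reading of this suggestion, sketched under stated assumptions: rather than special-casing `nn.batch_matmul` inside `DefaultFP16OpDefinition`, each op registers its own conversion rule and the pass only performs a lookup. The standalone sketch below uses a `std::unordered_map` keyed by op name as a stand-in for TVM's per-op attribute registry; `ConversionRule`, `RuleRegistry`, and the dtype strings are hypothetical, not TVM API.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

enum FP16ConversionCategory { RED, GRAY, GREEN };

// Accumulation and output dtypes, as strings in this standalone sketch.
struct FP16OpDType {
  std::string accumulation_dtype;
  std::string output_dtype;
};

// Hypothetical per-op rule: the op's color plus the dtypes to use when GREEN.
struct ConversionRule {
  FP16ConversionCategory color;
  FP16OpDType dtypes;
};

// Hypothetical registry standing in for a per-op attribute.
class RuleRegistry {
 public:
  void Register(const std::string& op_name, const ConversionRule& rule) {
    rules_[op_name] = rule;
  }

  // Ops without a registered rule default to RED (kept in fp32).
  ConversionRule Lookup(const std::string& op_name) const {
    auto it = rules_.find(op_name);
    if (it == rules_.end()) return {RED, {"float32", "float32"}};
    return it->second;
  }

 private:
  std::unordered_map<std::string, ConversionRule> rules_;
};
```

With this shape, the `nn.batch_matmul` workaround becomes a registration such as `registry.Register("nn.batch_matmul", {GREEN, {"float16", "float16"}});` placed next to the op, and the pass itself needs no op-specific branches.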





[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648620237



##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;
+    } else {
+      final_color = initial_color;
+    }
+
+    // Create the new arguments to the call.
+    DataType wanted_arg_dtypes = final_color == GREEN ? DataType::Float(16) : DataType::Float(32);
+    auto call_args_and_types = CastAllArgs(new_args, new_arg_types, wanted_arg_dtypes);
+
+    Array<Expr> call_args = call_args_and_types.first;
+    Array<Type> call_arg_types;
+
+    if (call_node->op.as<FunctionNode>()) {
+      // Function Nodes don't store type info in the Call, it should be a []
+      call_arg_types = call_node->type_args;
+    } else {
+      call_arg_types = call_args_and_types.second;
+    }
+
+    // Finally create the new attributes.
+    if (final_color == GREEN) {
+      FP16OpDType output_dtypes = output_dtype_func(call_node);
+
+      Attrs new_attrs = GetNewAttrs(call_node, output_dtypes.accumulation_dtype);
+      Expr output = Call(new_op, call_args, new_attrs, call_arg_types, call_node->span);
+      if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) {
+        output = CastArg(output, GetType(output), output_dtypes.output_dtype);

Review comment:
       I see... this mechanism is interesting, and I haven't paid much attention to it. At first glance, I would worry that the cache could blow up when the model is very large, but I'll probably take a deeper look at this mechanism later.
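For reference, the cache in `CachedCast` is keyed by (node pointer, wanted dtype), and each cast it creates adds at most two entries: the forward entry and a reverse entry that points back to the original node. Its size therefore grows with the number of distinct casts inserted, not with the number of lookups. The standalone sketch below reproduces that bookkeeping with integer node ids and string dtypes in place of `const ExprNode*` and `DataType`; `CastCache` and the id scheme are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Hash for std::pair keys, combined in the style of boost::hash_combine.
struct PairHash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& p) const {
    std::size_t h1 = std::hash<T1>()(p.first);
    std::size_t h2 = std::hash<T2>()(p.second);
    return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
  }
};

using NodeId = int;         // stand-in for const ExprNode*
using DType = std::string;  // stand-in for DataType

// Hypothetical model of the cast cache used by CachedCast.
class CastCache {
 public:
  // Returns a node holding `node` cast to `wanted`, creating one on a miss.
  NodeId Cast(NodeId node, const DType& node_dtype, const DType& wanted) {
    auto it = cache_.find({node, wanted});
    if (it != cache_.end()) return it->second;
    NodeId result = (node_dtype == wanted) ? node : next_id_++;
    cache_[{node, wanted}] = result;
    // Reverse entry: casting the result back yields the original node.
    cache_[{result, node_dtype}] = node;
    return result;
  }

  std::size_t size() const { return cache_.size(); }

 private:
  std::unordered_map<std::pair<NodeId, DType>, NodeId, PairHash> cache_;
  NodeId next_id_ = 1000;  // ids handed out to newly created cast nodes
};
```

Repeated requests for the same cast are served from the map, and casting back to the source dtype returns the original node instead of stacking a second cast, so the extra memory stays proportional to the number of casts actually inserted.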





[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652134539



##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,22 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def ToMixedPrecision(
+    mixed_precision_type="float16", ignore_missing_ops=True, warn_missing_ops=True
+):
+    """
+    Automatic mixed precision rewriter. Rewrite an FP32 relay graph into a version
+    where as many operations as possible are in the target mixed_precision_type.
+
+    Note this does mutate the original graph putting it in a bad state potentially.
+
+    TODO(AndrewZhaoLuo): don't mutate the original graph.

Review comment:
       That problem is pretty old. The pass no longer seems to have it, so I removed the note.





[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648649618



##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,20 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def AMPRewrite():

Review comment:
       Hmm, I would still prefer `ToMixedPrecision` if that is fine with you.
   
   The example you list only works for me because it exists under the `amp` namespace. `AutoCast` by itself, without being part of `torch.cuda.amp`, does not convey mixed precision.





[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r646890463



##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };

Review comment:
       By default it would be RED and a warning would be emitted.
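A minimal sketch of that default, modeled on `DefaultFP16Colorer::operator()` in this header: an op name missing from all three lists is colored RED (kept in fp32) with a warning, and is treated as an error only when missing ops are not ignored. This standalone version replaces TVM's `LOG(WARNING)`/`LOG(FATAL)` with `std::cerr` and an exception; `InitialColor` is a hypothetical name.

```cpp
#include <cassert>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

enum FP16ConversionCategory { RED, GRAY, GREEN };

// Looks up the initial color of an op; unknown ops fall back to RED.
FP16ConversionCategory InitialColor(
    const std::unordered_map<std::string, FP16ConversionCategory>& op_to_color,
    const std::string& op_name, bool ignore_missing = true) {
  auto it = op_to_color.find(op_name);
  if (it != op_to_color.end()) return it->second;
  if (!ignore_missing) {
    // Stand-in for LOG(FATAL).
    throw std::runtime_error("Op " + op_name + " is not in the fp16 conversion lists");
  }
  // Stand-in for LOG(WARNING): keep the unknown op in fp32.
  std::cerr << "Warning: op " << op_name << " is not in the fp16 conversion lists\n";
  return RED;
}
```

Defaulting to RED is the conservative choice: an unlisted op keeps its original precision, so the worst case is a missed speedup rather than a numerical regression.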





[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652189117



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));

Review comment:
       The `Downcast` calls check types at runtime. I also added a check on the size of the returned array.
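Below is a minimal sketch of the check described above. It is written without TVM so that it compiles on its own. It assumes the op's registered function returns three fields, [category, accumulation dtype, output dtype], following the comment on `FTVMMixedPrecisionConversionType`. Those fields are modeled here as plain strings rather than `Array<ObjectRef>`. `OpDescriptor` and `ParseOpDescriptor` are hypothetical names.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Mirrors the three categories used by the pass.
enum MixedTypeConversionCategory : int {
  MIXED_PRECISION_ALWAYS = 0,
  MIXED_PRECISION_FOLLOW = 1,
  MIXED_PRECISION_NEVER = 2
};

// Hypothetical plain-C++ stand-in for the [category, accumulation, output] array.
struct OpDescriptor {
  MixedTypeConversionCategory category;
  std::string accumulation_dtype;
  std::string output_dtype;
};

// Validates the raw fields before trusting them: size first, then category range.
OpDescriptor ParseOpDescriptor(const std::vector<std::string>& fields) {
  if (fields.size() != 3) {
    throw std::invalid_argument("expected 3 fields, got " + std::to_string(fields.size()));
  }
  int raw_category = std::stoi(fields[0]);  // throws std::invalid_argument if not numeric
  if (raw_category < MIXED_PRECISION_ALWAYS || raw_category > MIXED_PRECISION_NEVER) {
    throw std::out_of_range("unknown conversion category " + fields[0]);
  }
  return {static_cast<MixedTypeConversionCategory>(raw_category), fields[1], fields[2]};
}
```

Checking the size before indexing keeps a malformed registration from becoming an out-of-bounds access. The range check catches an integer that does not map to any enum value.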







[GitHub] [tvm] AndrewZhaoLuo commented on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-863603673


   @mbrookhart PTAL





[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r646890115



##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1145,6 +1144,21 @@ def AnnotateSpans():
     Returns
     -------
     ret : tvm.transform.Pass
-        The regsistered AnnotateSpans pass.
+        The registered AnnotateSpans pass.
     """
     return _ffi_api.AnnotateSpans()
+
+
+def RewriteFP16():

Review comment:
       Good idea. Done.

##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };
+
+using OpStringSet = std::unordered_set<std::string>;
+
+// Default lists inspired from TF's classifications:
+// github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+// They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
+OpStringSet DEFAULT_GREEN_LIST({
+    "nn.conv1d",
+    "nn.conv2d",
+    "nn.conv3d",
+    "nn.conv1d_transpose",
+    "nn.conv2d_transpose",
+    "nn.conv3d_transpose",
+    "nn.dense",
+    "nn.batch_matmul",
+});
+OpStringSet DEFAULT_GRAY_LIST({
+    // These ops add new data or change shape
+    "nn.pad",
+    "nn.batch_flatten",
+    "concatenate",
+    "zeros",
+    "split",
+    "squeeze",
+    "transpose",
+    "expand_dims",
+    "reshape",
+    "dyn.reshape",
+    "broadcast_to_like",
+    "dyn.broadcast_to",
+    "strided_slice",
+    "dyn.strided_slice",
+    "take",
+    "argwhere",
+    "where",
+    "tile",
+    "dyn.tile",
+    "scatter",
+    "full",
+    "dyn.full",
+    // Comparison
+    "less",
+    "greater",
+    "less_equal",
+    "greater_equal",
+    // By definition copy and cast will become green or red based on inputs
+    "copy",
+    "cast",
+    "cast_like",
+    // Simple arithmetic
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "nn.bias_add",
+    "nn.batch_norm",
+    "sum",
+    "mean",
+    "sqrt",
+    "shape_of",
+    // Simple activations
+    "max",
+    "min",
+    "maximum",
+    "minimum",
+    "nn.relu",
+    "nn.leaky_relu",
+    "nn.prelu",
+    "nn.dropout",
+    // Complicated activations which saturate in a narrow range
+    "sigmoid",
+    "tanh",
+    // Pooling operations
+    "nn.max_pool1d",
+    "nn.max_pool2d",
+    "nn.max_pool3d",
+    "nn.avg_pool1d",
+    "nn.avg_pool2d",
+    "nn.avg_pool3d",
+    // "nn.global_max_pool1d", // does not exist yet
+    "nn.global_max_pool2d",
+    // "nn.global_max_pool3d", // does not exist yet
+    // "nn.global_avg_pool1d", // does not exist yet
+    "nn.global_avg_pool2d",
+    // "nn.global_avg_pool3d", // does not exist yet
+    "nn.adaptive_max_pool1d",
+    "nn.adaptive_max_pool2d",
+    "nn.adaptive_max_pool3d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+});
+OpStringSet DEFAULT_RED_LIST({
+    // In general if |f(x)| >> |x| for expected inputs then put the op here.
+    "exp",
+    "power",
+    "nn.cross_entropy",
+    "nn.cross_entropy_with_logits",
+    "nn.softmax",
+    "nn.l2_normalize",
+    // Error function doesn't seem to be able to be lowered into fp16 version in llvm.
+    // Move to gray list when it does.
+    "erf",
+});
+
+class DefaultFP16Colorer {
+  /* The default class to initially color ops for conversion using lists.
+
+  Creates a callable which given a CallNode* returns the node's color.
+  */
+ private:
+  std::unordered_map<std::string, FP16ConversionCategory> op_to_initial_color;
+
+ public:
+  DefaultFP16Colorer(OpStringSet red_list = DEFAULT_RED_LIST,
+                     OpStringSet gray_list = DEFAULT_GRAY_LIST,
+                     OpStringSet green_list = DEFAULT_GREEN_LIST) {
+    std::vector<std::pair<OpStringSet, FP16ConversionCategory>> lists_and_colors{
+        {red_list, RED}, {gray_list, GRAY}, {green_list, GREEN}};
+
+    for (auto list_and_color : lists_and_colors) {
+      OpStringSet ops = list_and_color.first;
+      FP16ConversionCategory color = list_and_color.second;
+      for (std::string op_name : ops) {
+        op_to_initial_color.insert({{op_name, color}});
+      }
+    }
+  }
+
+  FP16ConversionCategory operator()(const CallNode* call, bool ignore_missing = true) {
+    if (auto* op_node = (call->op).as<tvm::OpNode>()) {
+      std::string op_name = op_node->name;
+      auto color = op_to_initial_color.find(op_name);
+
+      if (color == op_to_initial_color.end()) {
+        if (ignore_missing) {
+          LOG(WARNING) << "Op name " << op_name << " not in included in fp16 conversion lists!.";

Review comment:
       Done
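One detail of the `DefaultFP16Colorer` constructor quoted above is easy to miss. `std::unordered_map::insert` does not overwrite an existing key, and the lists are visited in the order red, gray, green. As a result, an op that appears in more than one list keeps the first, most conservative category.

The standalone sketch below shows that construction. `BuildColorMap` is a hypothetical name. Unlike the quoted constructor, it takes the sets by const reference rather than copying them.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

enum FP16ConversionCategory { RED, GRAY, GREEN };
using OpStringSet = std::unordered_set<std::string>;

std::unordered_map<std::string, FP16ConversionCategory> BuildColorMap(
    const OpStringSet& red_list, const OpStringSet& gray_list, const OpStringSet& green_list) {
  std::unordered_map<std::string, FP16ConversionCategory> op_to_color;
  // Visit lists from most to least conservative category.
  const std::vector<std::pair<const OpStringSet*, FP16ConversionCategory>> lists_and_colors{
      {&red_list, RED}, {&gray_list, GRAY}, {&green_list, GREEN}};
  for (const auto& list_and_color : lists_and_colors) {
    for (const std::string& op_name : *list_and_color.first) {
      // insert() leaves an existing entry untouched, so the earlier list wins.
      op_to_color.insert({op_name, list_and_color.second});
    }
  }
  return op_to_color;
}
```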

##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };
+
+using OpStringSet = std::unordered_set<std::string>;
+
+// Default lists inspired from TF's classifications:
+// github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+// They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
+OpStringSet DEFAULT_GREEN_LIST({
+    "nn.conv1d",
+    "nn.conv2d",
+    "nn.conv3d",
+    "nn.conv1d_transpose",
+    "nn.conv2d_transpose",
+    "nn.conv3d_transpose",
+    "nn.dense",
+    "nn.batch_matmul",
+});
+OpStringSet DEFAULT_GRAY_LIST({
+    // These ops add new data or change shape
+    "nn.pad",
+    "nn.batch_flatten",
+    "concatenate",
+    "zeros",
+    "split",
+    "squeeze",
+    "transpose",
+    "expand_dims",
+    "reshape",
+    "dyn.reshape",
+    "broadcast_to_like",
+    "dyn.broadcast_to",
+    "strided_slice",
+    "dyn.strided_slice",
+    "take",
+    "argwhere",
+    "where",
+    "tile",
+    "dyn.tile",
+    "scatter",
+    "full",
+    "dyn.full",
+    // Comparison
+    "less",
+    "greater",
+    "less_equal",
+    "greater_equal",
+    // By definition copy and cast will become green or red based on inputs
+    "copy",
+    "cast",
+    "cast_like",
+    // Simple arithmetic
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "nn.bias_add",
+    "nn.batch_norm",
+    "sum",
+    "mean",
+    "sqrt",
+    "shape_of",
+    // Simple activations
+    "max",
+    "min",
+    "maximum",
+    "minimum",
+    "nn.relu",
+    "nn.leaky_relu",
+    "nn.prelu",
+    "nn.dropout",
+    // Complicated activations which saturate in a narrow range
+    "sigmoid",
+    "tanh",
+    // Pooling operations
+    "nn.max_pool1d",
+    "nn.max_pool2d",
+    "nn.max_pool3d",
+    "nn.avg_pool1d",
+    "nn.avg_pool2d",
+    "nn.avg_pool3d",
+    // "nn.global_max_pool1d", // does not exist yet
+    "nn.global_max_pool2d",
+    // "nn.global_max_pool3d", // does not exist yet
+    // "nn.global_avg_pool1d", // does not exist yet
+    "nn.global_avg_pool2d",
+    // "nn.global_avg_pool3d", // does not exist yet
+    "nn.adaptive_max_pool1d",
+    "nn.adaptive_max_pool2d",
+    "nn.adaptive_max_pool3d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+});
+OpStringSet DEFAULT_RED_LIST({
+    // In general if |f(x)| >> |x| for expected inputs then put the op here.
+    "exp",
+    "power",
+    "nn.cross_entropy",
+    "nn.cross_entropy_with_logits",
+    "nn.softmax",
+    "nn.l2_normalize",
+    // Error function doesn't seem to be able to be lowered into fp16 version in llvm.
+    // Move to gray list when it does.
+    "erf",
+});
+
+class DefaultFP16Colorer {
+  /* The default class to initially color ops for conversion using lists.
+
+  Creates a callable which given a CallNode* returns the node's color.
+  */
+ private:
+  std::unordered_map<std::string, FP16ConversionCategory> op_to_initial_color;
+
+ public:
+  DefaultFP16Colorer(OpStringSet red_list = DEFAULT_RED_LIST,
+                     OpStringSet gray_list = DEFAULT_GRAY_LIST,
+                     OpStringSet green_list = DEFAULT_GREEN_LIST) {
+    std::vector<std::pair<OpStringSet, FP16ConversionCategory>> lists_and_colors{
+        {red_list, RED}, {gray_list, GRAY}, {green_list, GREEN}};
+
+    for (auto list_and_color : lists_and_colors) {
+      OpStringSet ops = list_and_color.first;
+      FP16ConversionCategory color = list_and_color.second;
+      for (std::string op_name : ops) {
+        op_to_initial_color.insert({{op_name, color}});
+      }
+    }
+  }
+
+  FP16ConversionCategory operator()(const CallNode* call, bool ignore_missing = true) {
+    if (auto* op_node = (call->op).as<tvm::OpNode>()) {
+      std::string op_name = op_node->name;
+      auto color = op_to_initial_color.find(op_name);
+
+      if (color == op_to_initial_color.end()) {
+        if (ignore_missing) {
+          LOG(WARNING) << "Op name " << op_name << " not in included in fp16 conversion lists!.";
+          return RED;
+        } else {
+          LOG(FATAL) << "Op name " << op_name << " not in included in fp16 lists!.";

Review comment:
       Done
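The missing-op branch can be sketched the same way. The hypothetical `ColorOp` helper below stands in for `DefaultFP16Colorer::operator()`:

- When `ignore_missing` is set, it falls back to RED (keep the op in fp32). Otherwise it throws.
- The separate `warn_missing` flag is an assumption, modeled on the `warn_missing_ops` option that appears in the later `to_mixed_precision.cc` revision.
- `std::cerr` and exceptions replace TVM's `LOG(WARNING)` and `LOG(FATAL)`.

```cpp
#include <cassert>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

enum FP16ConversionCategory { RED, GRAY, GREEN };
using ColorMap = std::unordered_map<std::string, FP16ConversionCategory>;

// Returns the op's category. Unknown ops either fall back to RED (optionally with a
// warning) or cause a hard failure, depending on the flags.
FP16ConversionCategory ColorOp(const ColorMap& op_to_color, const std::string& op_name,
                               bool ignore_missing = true, bool warn_missing = true) {
  auto it = op_to_color.find(op_name);
  if (it != op_to_color.end()) return it->second;
  if (!ignore_missing) {
    throw std::runtime_error("Op name " + op_name + " is not included in the fp16 conversion lists");
  }
  if (warn_missing) {
    std::cerr << "Warning: op name " << op_name
              << " is not included in the fp16 conversion lists; keeping it in fp32\n";
  }
  return RED;
}
```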







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648598356



##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc

Review comment:
       I'll rename it to something like `amp.cc` once we decide what to call the pass. 
   
   I would like the file names to closely match the user-facing name, e.g. `ToMixedPrecision` --> `to_mixed_precision.cc`.
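The naming convention above amounts to a CamelCase-to-snake_case conversion of the pass name. The hypothetical helper below illustrates it.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string>

// Hypothetical helper: derives a source file stem from a pass name by inserting an
// underscore before each non-leading uppercase letter and lowercasing everything.
std::string PassNameToFileStem(const std::string& pass_name) {
  std::string stem;
  for (std::size_t i = 0; i < pass_name.size(); ++i) {
    unsigned char c = static_cast<unsigned char>(pass_name[i]);
    if (std::isupper(c) && i > 0) stem.push_back('_');
    stem.push_back(static_cast<char>(std::tolower(c)));
  }
  return stem;
}
```

This naive rule splits acronyms letter by letter; for example, `RewriteFP16` becomes `rewrite_f_p16`. Names that contain acronyms would still need a hand-picked file name.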







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652184550



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;

Review comment:
       done
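The `cast_nodes_cache` discussed here memoizes casts keyed on (node pointer, target dtype). It also records a reverse entry, so casting a cast back to its source dtype returns the original node.

The sketch below reproduces that idea without TVM, under these assumptions:

- `Node` is a hypothetical stand-in for a Relay expression.
- Dtypes are plain strings.
- The non-float early return of the real `CachedCast` is omitted.

`PairHash` uses the classic `boost::hash_combine` formula, which shifts the running seed rather than the incoming hash. The quoted `pair_hash` arranges those terms slightly differently.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for a Relay expression: a dtype plus the node it was cast from.
struct Node {
  std::string dtype;
  std::shared_ptr<Node> input;  // non-null only for cast nodes
};
using NodePtr = std::shared_ptr<Node>;

// Classic boost::hash_combine formula: mixes the second hash into the running seed.
struct PairHash {
  template <class T1, class T2>
  std::size_t operator()(const std::pair<T1, T2>& p) const {
    std::size_t seed = std::hash<T1>()(p.first);
    seed ^= std::hash<T2>()(p.second) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
  }
};

class CastCache {
 public:
  // Returns a cast of `expr` to `wanted`, reusing an earlier cast when one exists.
  NodePtr CachedCast(const NodePtr& expr, const std::string& wanted) {
    if (expr->dtype == wanted) return expr;
    auto it = cache_.find({expr.get(), wanted});
    if (it != cache_.end()) return it->second;
    NodePtr result = std::make_shared<Node>(Node{wanted, expr});
    cache_[{expr.get(), wanted}] = result;
    // Reverse entry: casting the new node back yields the original, not a second cast.
    cache_[{result.get(), expr->dtype}] = expr;
    return result;
  }

 private:
  std::unordered_map<std::pair<const Node*, std::string>, NodePtr, PairHash> cache_;
};
```

In this sketch the cache values hold owning pointers to both the cast and its source. That is why the raw pointers used as keys stay valid for the cache's lifetime.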







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648601951



##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };
+
+using OpStringSet = std::unordered_set<std::string>;
+
+// Default lists inspired from TF's classifications:

Review comment:
       Good advice. 
   
   I'll use more descriptive terms instead of RED/GRAY/GREEN.
   I'll also make the warning messages configurable by the user.
   Registering an attribute on each op sounds like a good idea; do you have an example of this strategy I could look at?
   User-defined rules from Python are a goal I will work toward, though it might take a little longer.
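On the question of an example of per-op attribute registration: the `to_mixed_precision.cc` revision quoted above looks up a `FTVMMixedPrecisionConversionType` function for each op via `Op::GetAttrMap`. The plain-C++ sketch below only illustrates the shape of that pattern.

`ConversionRegistry`, `ConversionRule`, and `ConversionInfo` are hypothetical names, not TVM APIs. Letting a re-registration overwrite an earlier rule is also an assumption, made here to show how user overrides could work.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

enum MixedTypeConversionCategory : int {
  MIXED_PRECISION_ALWAYS = 0,
  MIXED_PRECISION_FOLLOW = 1,
  MIXED_PRECISION_NEVER = 2
};

struct ConversionInfo {
  MixedTypeConversionCategory category;
  std::string accumulation_dtype;
  std::string output_dtype;
};

// Per-op rule: given the target mixed precision dtype, decide how the op is converted.
using ConversionRule = std::function<ConversionInfo(const std::string& target_dtype)>;

// Each op registers its own rule instead of appearing in a central list, so new ops
// (or user-supplied overrides) can be added without editing the pass itself.
class ConversionRegistry {
 public:
  void Register(const std::string& op_name, ConversionRule rule) {
    rules_[op_name] = std::move(rule);  // later registrations replace earlier ones
  }
  bool Has(const std::string& op_name) const { return rules_.count(op_name) > 0; }
  ConversionInfo Query(const std::string& op_name, const std::string& target_dtype) const {
    return rules_.at(op_name)(target_dtype);
  }

 private:
  std::unordered_map<std::string, ConversionRule> rules_;
};
```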
   







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r647776445



##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    return h1 ^ (h2 << 1);
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    if ((mutable_attrs->out_dtype).is_float()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    if ((mutable_attrs->dtype).is_float()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}

Review comment:
       Yes, right now I actually depend on a post-order traversal (since we want all arguments to a call node to be mutated before we decide whether to convert that call node to fp16). I'll look into MixedModeMutator to solve this issue.
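A toy Python model of why the traversal order matters: each call's precision is decided only after all of its arguments have been visited, and the walk uses an explicit stack instead of recursion (as we understand it, avoiding deep recursion on long dataflow chains is part of the motivation for TVM's `MixedModeMutator`). `Node`, `post_order`, and `decide_precision` are hypothetical names. Treating graph inputs as always compatible mirrors the `VarNode`/`ConstantNode` check in the later revision of `VisitExpr_` quoted in this thread; the assumption that ALWAYS ops produce reduced-precision outputs is ours.

```python
# Hypothetical model of a post-order precision decision (not TVM code).
from dataclasses import dataclass, field

ALWAYS, FOLLOW, NEVER = "always", "follow", "never"


@dataclass(eq=False)
class Node:
    name: str
    category: str = None  # None marks a graph input (var or constant)
    args: list = field(default_factory=list)


def post_order(root):
    """Order nodes so each appears after all of its args, without recursion."""
    order, visited, stack = [], set(), [(root, False)]
    while stack:
        node, args_done = stack.pop()
        if args_done:
            order.append(node)
            continue
        if id(node) in visited:
            continue
        visited.add(id(node))
        stack.append((node, True))  # popped again once all args are emitted
        for arg in reversed(node.args):
            stack.append((arg, False))
    return order


def decide_precision(root):
    """Map id(node) -> True if the node runs in reduced precision."""
    reduced = {}
    for node in post_order(root):
        if node.category is None:
            reduced[id(node)] = True  # inputs can always be cast
        elif node.category == FOLLOW:
            # Safe only because every arg was already decided (post-order).
            reduced[id(node)] = all(reduced[id(a)] for a in node.args)
        else:
            # Assumption: ALWAYS ops produce reduced-precision outputs.
            reduced[id(node)] = node.category == ALWAYS
    return reduced
```

For a graph `add(exp(relu(conv(x, w))), relu(...))`, `relu` follows `conv` into reduced precision, while `add` stays in float32 because one of its inputs (`exp`) does; in the real pass that is where casts back to float32 would be inserted.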

##########
File path: tests/python/frontend/mxnet/test_forward.py
##########
@@ -1223,6 +1221,8 @@ def verify(shape, axis=1, fix_gamma=False):
 
 @tvm.testing.uses_gpu
 def test_forward_instance_norm():
+    np.random.seed(90)
+

Review comment:
       Oh OK, that's an interesting idea. I had a failure where the passing rtol was 1.05e-5, so I'm just going to lower the tolerance.
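For reference, `numpy.testing.assert_allclose` checks each element against `|actual - desired| <= atol + rtol * |desired|`. The small sketch below (the helper names `within_tolerance` and `required_rtol` are ours) shows why a relative error of about 1.05e-5 fails at `rtol=1e-5` and what rtol that single element would need.

```python
# Elementwise rule used by numpy.testing.assert_allclose / numpy.allclose:
#   |actual - desired| <= atol + rtol * |desired|
def within_tolerance(actual, desired, rtol, atol=0.0):
    return abs(actual - desired) <= atol + rtol * abs(desired)


def required_rtol(actual, desired, atol=0.0):
    """Smallest rtol that lets this single pair pass (desired must be non-zero)."""
    return max(abs(actual - desired) - atol, 0.0) / abs(desired)


desired = 1.0
actual = desired + 1.05e-5  # relative error of about 1.05e-5
print(within_tolerance(actual, desired, rtol=1e-5))  # False
print(within_tolerance(actual, desired, rtol=2e-5))  # True
print(required_rtol(actual, desired))  # roughly 1.05e-05
```

Because the bound scales with `|desired|`, an element close to zero can still fail a generous rtol; that case is what `atol` is for.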







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648611836



##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;
+    } else {
+      final_color = initial_color;
+    }
+
+    // Create the new arguments to the call.
+    DataType wanted_arg_dtypes = final_color == GREEN ? DataType::Float(16) : DataType::Float(32);
+    auto call_args_and_types = CastAllArgs(new_args, new_arg_types, wanted_arg_dtypes);
+
+    Array<Expr> call_args = call_args_and_types.first;
+    Array<Type> call_arg_types;
+
+    if (call_node->op.as<FunctionNode>()) {
+      // Function Nodes don't store type info in the Call, it should be a []
+      call_arg_types = call_node->type_args;
+    } else {
+      call_arg_types = call_args_and_types.second;
+    }
+
+    // Finally create the new attributes.
+    if (final_color == GREEN) {
+      FP16OpDType output_dtypes = output_dtype_func(call_node);
+
+      Attrs new_attrs = GetNewAttrs(call_node, output_dtypes.accumulation_dtype);
+      Expr output = Call(new_op, call_args, new_attrs, call_arg_types, call_node->span);
+      if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) {
+        output = CastArg(output, GetType(output), output_dtypes.output_dtype);

Review comment:
       Hmm, I don't believe so, since `CachedCast` also caches the reverse result.
   
   E.g. `CachedCast(A(GREEN), FP16)` would produce `A(GREEN) -> cast_to_fp16`.
   
   But internally it would cache (key `Node, wanted_dtype` --> value):
   `A(GREEN), FP16` --> `cast_to_fp16`
   `cast_to_fp16, FP32` --> `A(GREEN)`
   
   So attempting to cast `cast_to_fp16` to `fp32` would return `A(GREEN)`.
   
   It would still be worth having a test case to cover this and make sure, however.
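Following the suggestion to add a test, here is a pure-Python model of the two-way cache in `CachedCast`, with a check that casting the cast node back to the original dtype returns the original expression and creates no second cast. `Expr`, `CastCache`, and `casts_created` are hypothetical stand-ins; `id(...)` plays the role of the `const ExprNode*` key, and the cached values (plus `src`) keep both nodes alive so those ids are not reused.

```python
# Hypothetical model of CachedCast's two-way cache (not TVM code).
class Expr:
    def __init__(self, name, dtype, src=None):
        self.name, self.dtype, self.src = name, dtype, src


class CastCache:
    def __init__(self):
        self._cache = {}  # (id(expr), wanted_dtype) -> expr of wanted_dtype
        self.casts_created = 0

    def cached_cast(self, expr, wanted_dtype):
        if not expr.dtype.startswith("float"):
            return expr  # never cast non-float (e.g. integer) tensors
        key = (id(expr), wanted_dtype)
        if key in self._cache:
            return self._cache[key]
        if expr.dtype == wanted_dtype:
            result = expr
        else:
            result = Expr(f"cast({expr.name})", wanted_dtype, src=expr)
            self.casts_created += 1
        self._cache[key] = result
        # Record the reverse direction so casting back yields the original.
        self._cache[(id(result), expr.dtype)] = expr
        return result
```

With `a = Expr("A", "float32")`, `cached_cast(a, "float16")` creates one cast node, and calling `cached_cast` on that node with `"float32"` hands back `a` itself rather than stacking a second cast.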







[GitHub] [tvm] comaniac commented on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-878678403


   - You might need to rebuild TVM.
   - Please do not ask usage questions under feature PRs. Please post the question to the discuss forum instead.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] mbrookhart commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652142972



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }

Review comment:
       Let me see if I can find the unit test I hit this on...







[GitHub] [tvm] AndrewZhaoLuo commented on a change in pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r652182389



##########
File path: src/relay/transforms/to_mixed_precision.cc
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are          : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsMixedPrecisionType(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only target mixed precision type elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == mixed_precision_type;
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsMixedPrecisionType(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    if (expr_dtype == wanted_dtype) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  using MixedModeMutator::VisitExpr_;
+
+  explicit MixedPrecisionPass(DataType mixed_precision_type = DataType::Float(16),
+                              bool ignore_missing_ops = true, bool warn_missing_ops = true)
+      : MixedModeMutator(),
+        mixed_precision_type(mixed_precision_type),
+        ignore_missing_ops(ignore_missing_ops),
+        warn_missing_ops(warn_missing_ops) {
+    if (!mixed_precision_type.is_float() && !mixed_precision_type.is_bfloat16())
+      LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16 got "
+                 << mixed_precision_type;
+  }
+
+  Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final {
+    const CallNode* post_call_node = post.as<CallNode>();
+    if (!post_call_node) {
+      LOG(FATAL) << "Expected a CallNode for the rewrite got " << post;
+    }
+
+    Expr cur_op = post_call_node->op;
+
+    // Get info on the operation being called:
+    // conversion category (int), accumulation dtype (str), output dtype (str)
+    MixedTypeConversionCategory initial_category;
+    DataType accumulation_dtype, output_dtype;
+    if (cur_op.as<FunctionNode>()) {
+      // Avoid messing with functions to avoid changing signature
+      initial_category = MIXED_PRECISION_NEVER;
+      accumulation_dtype = DataType::Float(32);
+      output_dtype = DataType::Float(32);
+    } else if (cur_op.as<OpNode>()) {
+      static auto attr_map =
+          Op::GetAttrMap<FTVMMixedPrecisionConversionType>("FTVMMixedPrecisionConversionType");
+      Op op = Downcast<Op>(cur_op);
+      if (attr_map.count(op)) {
+        // Calculate the conversion category and dtypes from registered attribute.
+        FTVMMixedPrecisionConversionType func = attr_map[op];
+        Array<ObjectRef> op_descriptor =
+            func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type));
+
+        int64_t op_conversion_type = Downcast<Integer>(op_descriptor[0])->value;
+        initial_category = static_cast<MixedTypeConversionCategory>(op_conversion_type);
+        accumulation_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[1])));
+        output_dtype = DataType(String2DLDataType(Downcast<String>(op_descriptor[2])));
+      } else {
+        if (!ignore_missing_ops) LOG(FATAL) << "Op " << op->name << " not in conversion lists!";
+        if (warn_missing_ops) LOG(WARNING) << "Op " << op->name << " not in conversion lists!";
+
+        // If not registered, by default assume is a generic FOLLOW operation.
+        initial_category = MIXED_PRECISION_FOLLOW;
+        accumulation_dtype = DataType::Float(16);
+        output_dtype = DataType::Float(16);
+      }
+    } else {
+      LOG(FATAL) << "Unsupported op type in CallNode: " << pre_call_node->op;
+    }
+
+    // First check if all the new mutated args are in lower precision form
+    Array<Type> cur_arg_types;
+    bool all_args_mixed_type_compatible = true;
+    for (Expr arg : post_call_node->args) {
+      Type cur_arg_type = GetType(arg);
+      cur_arg_types.push_back(cur_arg_type);
+
+      if (initial_category == MIXED_PRECISION_FOLLOW && all_args_mixed_type_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_mixed_type_compatible = IsMixedPrecisionType(cur_arg_type, true) ||
+                                        arg->IsInstance<VarNode>() ||
+                                        arg->IsInstance<ConstantNode>();
+        all_args_mixed_type_compatible &= is_mixed_type_compatible;
+      }
+    }
+
+    // Determine the final category we want for conversion
+    MixedTypeConversionCategory final_category;
+    if (initial_category == MIXED_PRECISION_FOLLOW) {
+      final_category =
+          all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER;
+    } else {
+      final_category = initial_category;
+    }
+
+    // Create the new arguments to the call.
+    DataType wanted_arg_dtypes =
+        final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type : DataType::Float(32);
+    auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes);
+    Array<Expr> new_args = call_args_and_types.first;
+    Array<Type> new_arg_types;
+
+    if (pre_call_node->op.as<FunctionNode>()) {
+      // Function Nodes don't store type info in the Call, it should be a []
+      new_arg_types = pre_call_node->type_args;
+    } else {
+      new_arg_types = call_args_and_types.second;
+    }
+
+    // Finally create the new attributes.
+    if (final_category == MIXED_PRECISION_ALWAYS) {
+      Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype);
+      Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span);
+      if (accumulation_dtype != output_dtype) {
+        output = CastArg(output, GetType(output), output_dtype);
+      }
+      return output;
+    }
+
+    return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
+  }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    // Erase the ret_type annotation and let the normal pass recalculate
+    const_cast<FunctionNode*>(func)->ret_type = Type(nullptr);
+    return ExprMutator::VisitExpr_(func);
+  }
+
+  Expr VisitExpr_(const LetNode* op) final {
+    // First convert as much of the bound computation to lower precision as possible
+    Expr value = this->Mutate(op->value);
+
+    // Then rewrite the var type and associated expression
+    Var var = Downcast<Var>(this->Mutate(op->var));
+    VarNode* mutable_var = const_cast<VarNode*>((op->var).as<VarNode>());
+    mutable_var->type_annotation = GetType(value);
+    mutable_var->checked_type_ = mutable_var->type_annotation;
+
+    // Mutate body last as it may depend on previous results
+    Expr body = this->Mutate(op->body);
+    return Let(var, value, body, op->span);
+  }
+};
+
+Expr ToMixedPrecision(const Expr& expr, const DataType& mixed_precision_type,
+                      bool ignore_missing_ops, bool warn_missing_ops) {
+  MixedPrecisionPass converter =
+      MixedPrecisionPass(mixed_precision_type, ignore_missing_ops, warn_missing_ops);
+  auto result = converter.Mutate(expr);
+  return result;
+}
+
+namespace transform {
+
+Pass ToMixedPrecision(DataType mixed_precision_type, bool ignore_missing_ops,
+                      bool warn_missing_ops) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(
+            ToMixedPrecision(f, mixed_precision_type, ignore_missing_ops, warn_missing_ops));
+      };
+  return CreateFunctionPass(pass_func, 10, "ToMixedPrecision", {});

Review comment:
      Not sure; I based this on the level used by the fake_quantization_to_integer pass that @mbrookhart wrote. I'm not too sure about the guideline behind this, so any reference material would be welcome.
   
   I will change it to 0 if there are no objections.
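   For reference, my understanding of what the level controls: when a pass runs inside a `Sequential` pipeline, `PassContext` skips it if its `opt_level` is higher than the context's `opt_level`, unless the pass is listed in `required_pass` (and `disabled_pass` always wins). The sketch below is a minimal, self-contained model of that rule, assuming this is how `PassContext::PassEnabled` behaves; `PassInfoModel`, `PassContextModel`, and `PassEnabledModel` are hypothetical names, not TVM's API. Under this model, a level of 10 keeps the pass from running at `opt_level=3` unless it is explicitly required, while a level of 0 lets it run.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for the PassInfo and PassContext fields
// that matter for gating. Not TVM's actual classes.
struct PassInfoModel {
  std::string name;
  int opt_level;
};

struct PassContextModel {
  int opt_level;
  std::vector<std::string> required_pass;
  std::vector<std::string> disabled_pass;
};

// Returns true if `name` appears in `names`.
bool Contains(const std::vector<std::string>& names, const std::string& name) {
  return std::find(names.begin(), names.end(), name) != names.end();
}

// Assumed gating rule: an explicitly disabled pass never runs, an explicitly
// required pass always runs, and otherwise the pass runs only if its level
// does not exceed the context's level.
bool PassEnabledModel(const PassContextModel& ctx, const PassInfoModel& info) {
  if (Contains(ctx.disabled_pass, info.name)) return false;
  if (Contains(ctx.required_pass, info.name)) return true;
  return ctx.opt_level >= info.opt_level;
}
```

   As far as I know this check only applies to passes run through a `Sequential`; invoking a pass object directly on a module does not consult it.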





[GitHub] [tvm] comaniac commented on pull request #8069: [Relay] [Pass] Add mixed precision (e.g. FP16) model conversion pass

Posted by GitBox <gi...@apache.org>.
comaniac commented on pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#issuecomment-1002919660


   Please do not ask questions in the PR directly.
   
   Weights have cast ops because they are parameters rather than constants. You have to bind the parameters first, run ToMixedPrecision, and then run FoldConstant to remove the casts.
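   To illustrate why the order matters, here is a toy model built on a hypothetical mini-IR (not Relay's API). The conversion casts a float32 weight to float16 whether it is a free parameter or a bound constant, but constant folding can only collapse `cast(constant)` into a new constant; `cast(var)` has nothing to fold, so the cast survives. `Var`, `Const`, `Cast`, `BindParams`, `ToFp16`, and `FoldConstant` below are illustrative stand-ins for parameter binding, the conversion pass, and Relay's FoldConstant.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Toy expression IR: a hypothetical stand-in for Relay, for illustration only.
enum class Kind { kVar, kConstant, kCast };

struct Node;
using NodePtr = std::shared_ptr<Node>;

struct Node {
  Kind kind;
  std::string name;   // parameter name (kVar only)
  std::string dtype;  // result dtype, e.g. "float32" or "float16"
  double value;       // scalar payload (kConstant only)
  NodePtr arg;        // operand (kCast only)
};

NodePtr Var(const std::string& name, const std::string& dtype) {
  return std::make_shared<Node>(Node{Kind::kVar, name, dtype, 0.0, nullptr});
}

NodePtr Const(double value, const std::string& dtype) {
  return std::make_shared<Node>(Node{Kind::kConstant, "", dtype, value, nullptr});
}

NodePtr Cast(NodePtr arg, const std::string& dtype) {
  return std::make_shared<Node>(Node{Kind::kCast, "", dtype, 0.0, std::move(arg)});
}

// Binding: replace each parameter that has a value in `params` with that constant.
NodePtr BindParams(const NodePtr& e, const std::map<std::string, NodePtr>& params) {
  if (e->kind == Kind::kVar) {
    auto it = params.find(e->name);
    return it != params.end() ? it->second : e;
  }
  if (e->kind == Kind::kCast) return Cast(BindParams(e->arg, params), e->dtype);
  return e;  // constants are left as-is
}

// Stand-in for the conversion on a single weight: wrap float32 in a cast to float16.
NodePtr ToFp16(const NodePtr& e) { return e->dtype == "float32" ? Cast(e, "float16") : e; }

// Stand-in for constant folding: cast(constant) becomes a constant of the target
// dtype. A real folder would also round the value; this model only relabels the dtype.
NodePtr FoldConstant(const NodePtr& e) {
  if (e->kind != Kind::kCast) return e;
  NodePtr arg = FoldConstant(e->arg);
  if (arg->kind == Kind::kConstant) return Const(arg->value, e->dtype);
  return Cast(arg, e->dtype);  // cast(var) cannot be folded away
}

// Counts the chain of casts at the root of `e`.
int CountCasts(const NodePtr& e) { return e->kind == Kind::kCast ? 1 + CountCasts(e->arg) : 0; }
```

   In Relay itself, binding is typically done with `relay.build_module.bind_params_by_name` before running the two passes.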

