Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/06/09 18:25:40 UTC

[GitHub] [tvm] comaniac commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

comaniac commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r648541663



##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc

Review comment:
       I would suggest naming these files `amp.cc`/`amp.h` directly, as the pass should not be limited to FP32-to-FP16 conversion.

##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };
+
+using OpStringSet = std::unordered_set<std::string>;
+
+// Default lists inspired from TF's classifications:

Review comment:
       I'd prefer not to hard-code op lists in a pass, since that means updating this pass every time we add a new op. It would be better to follow the approach of other similar passes: register an attribute on each op, and fall back to a default behavior when an op doesn't have the attribute registered. With the current implementation, it is also impossible for users to supply their own rules from Python.
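A minimal sketch of the attribute-plus-default pattern, written as standalone C++ rather than against TVM so it compiles on its own. `MixedPrecisionKind`, `ConversionRuleRegistry`, and passing the fallback in at construction are hypothetical choices; in TVM the per-op rule would presumably live in the op registry (set with `RELAY_REGISTER_OP(...).set_attr<...>` and read back with `Op::GetAttrMap`), which Python can also write to.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical policy values, named after the ALWAYS/FOLLOW/NEVER idea; not a TVM enum.
enum class MixedPrecisionKind { kAlways, kFollow, kNever };

// A per-op rule. A real rule would inspect the Call node (attrs, input types);
// here it only receives the op name to keep the sketch self-contained.
using ConversionRule = std::function<MixedPrecisionKind(const std::string& op_name)>;

// Hypothetical registry standing in for per-op attributes in the op registry.
class ConversionRuleRegistry {
 public:
  explicit ConversionRuleRegistry(MixedPrecisionKind fallback) : fallback_(fallback) {}

  // Attach (or override) the rule for one op. New ops register themselves here,
  // so the pass never has to be edited when an op is added.
  void Register(const std::string& op_name, ConversionRule rule) {
    rules_[op_name] = std::move(rule);
  }

  // Ops without a registered rule get the default behavior.
  MixedPrecisionKind Query(const std::string& op_name) const {
    auto it = rules_.find(op_name);
    return it == rules_.end() ? fallback_ : it->second(op_name);
  }

 private:
  std::unordered_map<std::string, ConversionRule> rules_;
  MixedPrecisionKind fallback_;
};
```

Because a later `Register` call replaces an earlier one, the same hook would also serve user-defined rules.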

##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };

Review comment:
       Some suggestions:
   1. These are more like attributes than categories.
   2. Use straightforward terms, such as ALWAYS/FOLLOW/NEVER, instead of RED/GRAY/GREEN.
   3. Emitting a warning for every unspecified op may produce tedious output. We could make this configurable so users can decide whether to print these ops.
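A sketch of suggestions 2 and 3 together, assuming the mapping GREEN to ALWAYS, GRAY to FOLLOW, RED to NEVER. `ConversionPolicy` and `PolicyTable` are hypothetical names, and printing each unlisted op at most once is an added assumption about how to keep the output short.

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

// Suggested renaming (assumed mapping): GREEN -> ALWAYS, GRAY -> FOLLOW, RED -> NEVER.
enum class ConversionPolicy { ALWAYS, FOLLOW, NEVER };

// Hypothetical lookup table with an opt-in warning for ops that have no policy.
class PolicyTable {
 public:
  PolicyTable(std::unordered_map<std::string, ConversionPolicy> policies,
              ConversionPolicy fallback, bool warn_on_unlisted)
      : policies_(std::move(policies)), fallback_(fallback), warn_on_unlisted_(warn_on_unlisted) {}

  ConversionPolicy Lookup(const std::string& op) {
    auto it = policies_.find(op);
    if (it != policies_.end()) return it->second;
    // Warn only when the user asked for it, and at most once per op name.
    if (warn_on_unlisted_ && warned_.insert(op).second) {
      std::cerr << "warning: no mixed-precision policy for op " << op << "\n";
    }
    return fallback_;
  }

  std::size_t num_warned() const { return warned_.size(); }

 private:
  std::unordered_map<std::string, ConversionPolicy> policies_;
  ConversionPolicy fallback_;
  bool warn_on_unlisted_;
  std::unordered_set<std::string> warned_;
};
```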

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;
+    } else {
+      final_color = initial_color;
+    }
+
+    // Create the new arguments to the call.
+    DataType wanted_arg_dtypes = final_color == GREEN ? DataType::Float(16) : DataType::Float(32);
+    auto call_args_and_types = CastAllArgs(new_args, new_arg_types, wanted_arg_dtypes);
+
+    Array<Expr> call_args = call_args_and_types.first;
+    Array<Type> call_arg_types;
+
+    if (call_node->op.as<FunctionNode>()) {
+      // Function Nodes don't store type info in the Call, it should be a []
+      call_arg_types = call_node->type_args;
+    } else {
+      call_arg_types = call_args_and_types.second;
+    }
+
+    // Finally create the new attributes.
+    if (final_color == GREEN) {
+      FP16OpDType output_dtypes = output_dtype_func(call_node);
+
+      Attrs new_attrs = GetNewAttrs(call_node, output_dtypes.accumulation_dtype);
+      Expr output = Call(new_op, call_args, new_attrs, call_arg_types, call_node->span);
+      if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) {
+        output = CastArg(output, GetType(output), output_dtypes.output_dtype);

Review comment:
       Wouldn't this introduce unnecessary cast ops? For example, if the accumulation dtype is FP32 and the following op is RED, won't this produce `A(GREEN) - cast_to_fp16 - cast_to_fp32 - B(RED)`?
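One way to make this question concrete is to trace the reverse entry that `CachedCast` records (`cast_nodes_cache[{new_expr_node, expr_dtype}] = expr`). The toy model below mirrors only that forward/reverse bookkeeping; `Node` and `CastCache` are simplified stand-ins, not TVM classes, and string dtypes replace `DataType`.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Toy stand-in for a Relay expression: an op output or a cast of another node.
struct Node {
  std::string op;               // op name, or "cast"
  std::string dtype;            // e.g. "float16" or "float32"
  std::shared_ptr<Node> input;  // non-null only for casts
};
using NodePtr = std::shared_ptr<Node>;

// Mirrors the forward and reverse cache entries written by CachedCast in the diff.
class CastCache {
 public:
  NodePtr Cast(const NodePtr& expr, const std::string& wanted) {
    auto key = std::make_pair(static_cast<const Node*>(expr.get()), wanted);
    auto hit = cache_.find(key);
    if (hit != cache_.end()) return hit->second;

    NodePtr result =
        expr->dtype == wanted ? expr : std::make_shared<Node>(Node{"cast", wanted, expr});
    cache_[key] = result;
    // Reverse entry: asking for the cast's source dtype returns the original node.
    cache_[std::make_pair(static_cast<const Node*>(result.get()), expr->dtype)] = expr;
    return result;
  }

 private:
  std::map<std::pair<const Node*, std::string>, NodePtr> cache_;
};
```

If both the output cast on A and the argument cast for B go through the same cache, B's FP32 request resolves to A itself rather than to a second cast. The toy model ignores `ExprMutator` memoization and type inference, so a test against the real pass would still be needed to confirm the behavior.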

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }

Review comment:
       You don't need to perform this entire logic when `initial_color` is not GRAY, since `all_args_fp16_compatible` is only used to resolve GRAY ops.
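A simplified sketch of the early exit, with each argument's FP16-compatibility test modeled as a callback so the skipped work is observable. `FP16Color`, `CompatCheck`, and `ResolveColor` are hypothetical names. In the actual `VisitExpr_`, the arguments would still need to be mutated and typed, because `CastAllArgs` consumes `new_arg_types` regardless of color; only the compatibility test depends on the color being GRAY.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Same three colors as FP16ConversionCategory in the diff.
enum FP16Color { RED, GRAY, GREEN };

// Stands in for the IsFP16Type / VarNode / ConstantNode test on one argument.
using CompatCheck = std::function<bool()>;

// Resolve the final color, evaluating argument checks only when they can matter.
FP16Color ResolveColor(FP16Color initial, const std::vector<CompatCheck>& arg_checks) {
  if (initial != GRAY) return initial;  // GREEN and RED ignore argument dtypes
  for (const CompatCheck& is_compatible : arg_checks) {
    if (!is_compatible()) return RED;  // the first non-FP16 argument decides
  }
  return GREEN;
}
```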

##########
File path: python/tvm/relay/transform/transform.py
##########
@@ -1199,3 +1198,20 @@ def FakeQuantizationToInteger():
         The registered SimplifyExpr pass.
     """
     return _ffi_api.FakeQuantizationToInteger()
+
+
+def AMPRewrite():

Review comment:
       Since this is a user-facing API, we should think a bit more about making it straightforward. For example, `AMP` or `AutoCast` would be better names IMHO.

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+class AmpGraphCreator : public ExprMutator {

Review comment:
       Please use an iterative implementation instead of recursion to avoid stack overflow on deep graphs.
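A minimal sketch of an explicit-stack post-order traversal (inputs visited before their users) on a toy integer-indexed DAG; `Graph` and `PostOrder` are hypothetical, not TVM APIs. Within TVM, deriving from `MixedModeMutator` and overriding `Rewrite_` is one existing way to get non-recursive handling of dataflow nodes, though whether it fits this pass is a design choice for the PR.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Toy DAG: inputs[i] lists the nodes that node i consumes.
struct Graph {
  std::vector<std::vector<int>> inputs;
};

// Post-order traversal with an explicit stack; each node is emitted once,
// after all of its inputs, so deep chains cannot overflow the call stack.
std::vector<int> PostOrder(const Graph& g, int root) {
  std::vector<int> order;
  std::vector<bool> visited(g.inputs.size(), false);
  // Each frame: node id and whether its inputs have already been pushed.
  std::vector<std::pair<int, bool>> stack{{root, false}};
  while (!stack.empty()) {
    int node = stack.back().first;
    bool expanded = stack.back().second;
    stack.pop_back();
    if (expanded) {
      order.push_back(node);  // all inputs are already in `order`
      continue;
    }
    if (visited[node]) continue;  // shared subexpression reached twice
    visited[node] = true;
    stack.push_back({node, true});
    // Push inputs in reverse so they are processed in their original order.
    for (auto it = g.inputs[node].rbegin(); it != g.inputs[node].rend(); ++it) {
      if (!visited[*it]) stack.push_back({*it, false});
    }
  }
  return order;
}
```

The `visited` vector plays the role that the memo table plays in `ExprMutator`: a node shared by several users is rewritten only once.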

##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+// Default lists inspired by TF's classifications:
+// github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+// They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
+OpStringSet DEFAULT_GREEN_LIST({
+    "nn.conv1d",
+    "nn.conv2d",
+    "nn.conv3d",
+    "nn.conv1d_transpose",
+    "nn.conv2d_transpose",
+    "nn.conv3d_transpose",
+    "nn.dense",
+    "nn.batch_matmul",
+});
+OpStringSet DEFAULT_GRAY_LIST({
+    // These ops add new data or change shape
+    "nn.pad",
+    "nn.batch_flatten",
+    "concatenate",
+    "zeros",
+    "split",
+    "squeeze",
+    "transpose",
+    "expand_dims",
+    "reshape",
+    "dyn.reshape",
+    "broadcast_to_like",
+    "dyn.broadcast_to",
+    "strided_slice",
+    "dyn.strided_slice",
+    "take",
+    "argwhere",
+    "where",
+    "tile",
+    "dyn.tile",
+    "scatter",
+    "full",
+    "dyn.full",
+    // Comparison
+    "less",
+    "greater",
+    "less_equal",
+    "greater_equal",
+    // By definition copy and cast will become green or red based on inputs
+    "copy",
+    "cast",
+    "cast_like",
+    // Simple arithmetic
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "nn.bias_add",
+    "nn.batch_norm",
+    "sum",
+    "mean",
+    "sqrt",
+    "shape_of",
+    // Simple activations
+    "max",
+    "min",
+    "maximum",
+    "minimum",
+    "nn.relu",
+    "nn.leaky_relu",
+    "nn.prelu",
+    "nn.dropout",
+    // Complicated activations which saturate in a narrow range
+    "sigmoid",
+    "tanh",
+    // Pooling operations
+    "nn.max_pool1d",
+    "nn.max_pool2d",
+    "nn.max_pool3d",
+    "nn.avg_pool1d",
+    "nn.avg_pool2d",
+    "nn.avg_pool3d",
+    // "nn.global_max_pool1d", // does not exist yet
+    "nn.global_max_pool2d",
+    // "nn.global_max_pool3d", // does not exist yet
+    // "nn.global_avg_pool1d", // does not exist yet
+    "nn.global_avg_pool2d",
+    // "nn.global_avg_pool3d", // does not exist yet
+    "nn.adaptive_max_pool1d",
+    "nn.adaptive_max_pool2d",
+    "nn.adaptive_max_pool3d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+});
+OpStringSet DEFAULT_RED_LIST({
+    // In general if |f(x)| >> |x| for expected inputs then put the op here.
+    "exp",
+    "power",
+    "nn.cross_entropy",
+    "nn.cross_entropy_with_logits",
+    "nn.softmax",
+    "nn.l2_normalize",
+    // The error function does not seem to lower to an fp16 version in LLVM yet.
+    // Move to gray list when it does.
+    "erf",
+});
+
+class DefaultFP16Colorer {
+  /* The default class to initially color ops for conversion using lists.
+
+  Creates a callable which given a CallNode* returns the node's color.
+  */

Review comment:
       Please use the formal docstring format.

##########
File path: tests/python/relay/test_fp32_to_fp16_transform.py
##########
@@ -0,0 +1,328 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for testing FP32 -> FP16 pass"""
+from typing import Any, Dict, List
+
+import numpy as np
+import tvm
+from tvm import relay
+from tvm.relay.testing import lstm
+from tvm.relay.transform import AMPRewrite
+from tvm.relay.transform.transform import InferType
+
+
+def run_module(mod: tvm.runtime.Module, mod_params: Dict[str, Any]) -> List:
+    dev = tvm.device("llvm", 0)
+    intrp = relay.create_executor("debug", mod, device=dev, target="llvm")
+    result = intrp.evaluate()(**mod_params)
+    if isinstance(result, tvm.runtime.container.ADT):
+        result = [r.asnumpy() for r in result]
+        return result
+    else:
+        return [result.asnumpy()]
+
+
+def verify_fp32_fp16_output_close(
+    mod: tvm.runtime.Module, mod_params: Dict[str, Any], rtol: float = 1e-3, atol: float = 0
+) -> tvm.runtime.Module:
+    mod = InferType()(mod)
+    result_fp32 = run_module(mod, mod_params)
+    fp16_mod = AMPRewrite()(mod)
+    result_fp16 = run_module(fp16_mod, mod_params)
+
+    # Ensure the results are close
+    for fp32, fp16 in zip(result_fp32, result_fp16):
+        np.testing.assert_allclose(fp32, fp16, rtol=rtol, atol=atol)
+
+    return fp16_mod
+
+
+def test_lstm():
+    """A small stress test on a single unrolled lstm unit.
+
+    Has internal functions and let statements the pass must work on.
+    """
+    np.random.seed(5628)
+    units = 3
+    iterations = 5
+    mod, mod_params = lstm.get_workload(iterations=iterations, num_hidden=units)
+
+    # This is an unrolled lstm so each data should be the previous results but
+    # we don't care, we just want to stress test things.
+    for i in range(iterations):
+        mod_params["data" if i == 0 else f"data{i}"] = np.random.uniform(
+            -10, 10, (1, units)
+        ).astype("float32")
+
+    verify_fp32_fp16_output_close(mod, mod_params, rtol=0.01, atol=0.01)
+
+
+def test_convert_single_conv():
+    """Conv is a green listed operation meaning it will always use fp16 workload.
+
+    By default it accumulates to fp32 and outputs fp16.
+    """
+    np.random.seed(208)
+
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+    mod = tvm.IRModule.from_expr(conv)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+    }
+    fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+
+    expected_mod = tvm.IRModule.from_expr(
+        relay.cast(
+            relay.nn.conv2d(
+                relay.cast(data, "float16"),
+                relay.cast(weight, "float16"),
+                strides=(1, 1),
+                padding=(1, 1),
+                out_dtype="float32",
+            ),
+            "float16",
+        )
+    )
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+
+    assert not tvm.ir.structural_equal(fp16_mod, mod)
+    assert tvm.ir.structural_equal(fp16_mod, expected_mod)
+
+
+def test_convert_conv_bn():
+    """Conv is green and batch norm is gray. As Conv should output fp16, batch_norm should be green."""
+    np.random.seed(208)
+
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+
+    bn_shape = [5]
+    gamma = relay.var("gamma", shape=bn_shape)
+    beta = relay.var("beta", shape=bn_shape)
+    moving_mean = relay.var("moving_mean", shape=bn_shape)
+    moving_var = relay.var("moving_var", shape=bn_shape)
+    bn = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var)
+    mod = tvm.IRModule.from_expr(bn[0])
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+        "gamma": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+        "beta": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+        "moving_mean": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+        "moving_var": np.random.uniform(-1, 1, size=bn_shape).astype("float32"),
+    }
+    fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+
+    # Creating expected module
+    data = relay.cast(relay.var("data", shape=data_shape), "float16")
+    weight = relay.cast(relay.var("weight", shape=weight_shape), "float16")
+    conv = relay.cast(
+        relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32"),
+        "float16",
+    )
+
+    bn_shape = [5]
+    gamma = relay.cast(relay.var("gamma", shape=bn_shape), "float16")
+    beta = relay.cast(relay.var("beta", shape=bn_shape), "float16")
+    moving_mean = relay.cast(relay.var("moving_mean", shape=bn_shape), "float16")
+    moving_var = relay.cast(relay.var("moving_var", shape=bn_shape), "float16")
+    bn = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var)
+
+    expected_mod = tvm.IRModule.from_expr(bn[0])
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+    assert not tvm.ir.structural_equal(fp16_mod, mod)
+    assert tvm.ir.structural_equal(fp16_mod, expected_mod)
+
+
+def test_do_not_convert_softmax():
+    """Softmax is a red listed operation and therefore should never be fp16."""
+    np.random.seed(209)
+    shape = [1, 2, 3]
+    a = relay.var("a", shape=shape)
+    b = relay.nn.softmax(a)
+    mod = tvm.IRModule.from_expr(b)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "a": np.random.uniform(-1, 1, size=shape).astype("float32"),
+    }
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.0, rtol=0)
+    assert tvm.ir.structural_equal(mod, output_mod)
+
+
+def test_green_gray_propagates_simple():
+    """Conv is a green listed operation, while addition is gray.
+
+    As Conv outputs fp16 the add should be done in fp16.
+    """
+    np.random.seed(210)
+    data_shape = (1, 3, 32, 32)
+    weight_shape = (5, 3, 3, 3)
+    data = relay.var("data", shape=data_shape, dtype="float32")
+    weight = relay.var("weight", shape=weight_shape, dtype="float32")
+    conv = relay.nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1), out_dtype="float32")
+    conv = conv + conv
+    mod = tvm.IRModule.from_expr(conv)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=data_shape).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=weight_shape).astype("float32"),
+    }
+    fp16_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=1e-3)
+
+    conv_expr = relay.cast(
+        relay.nn.conv2d(
+            relay.cast(data, "float16"),
+            relay.cast(weight, "float16"),
+            strides=(1, 1),
+            padding=(1, 1),
+            out_dtype="float32",
+        ),
+        "float16",
+    )
+    expected_mod = tvm.IRModule.from_expr(conv_expr + conv_expr)
+    expected_mod = tvm.relay.transform.InferType()(expected_mod)
+
+    assert not tvm.ir.structural_equal(fp16_mod, mod)
+    assert tvm.ir.structural_equal(fp16_mod, expected_mod)
+
+
+def test_red_gray_propagates_simple():
+    """Everything after a softmax should be in FP32 (except green colored ops)."""
+    np.random.seed(211)
+    shape = [1, 2, 3]
+    a = relay.var("a", shape=shape)
+    b = relay.nn.softmax(a)
+    c = b + b
+    mod = tvm.IRModule.from_expr(c)
+    mod = tvm.relay.transform.InferType()(mod)
+
+    mod_params = {
+        "a": np.random.uniform(-1, 1, size=shape).astype("float32"),
+    }
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.0, rtol=0.0)
+
+    assert tvm.ir.structural_equal(mod, output_mod)
+
+
+def test_let_statement_simple():
+    """A 'simple' let statement example.
+
+    Notable here is the mutation of the bound variable types.
+    """
+    np.random.seed(211)
+    var1 = relay.var("var1", shape=[1, 20])
+    var2 = relay.var("var2", shape=[1, 20])
+
+    data = relay.var("data", shape=[1, 20])
+    weight = relay.var("weight", shape=[20, 20])
+
+    r1 = var1 + var1
+
+    r2 = var2 + var2
+    let2 = relay.Let(var2, relay.nn.dense(r1, weight, units=20), r2)
+    let1 = relay.Let(var1, relay.nn.dense(data, weight, units=20), let2)
+
+    mod = tvm.IRModule.from_expr(let1)
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"),
+    }
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+
+    # Construct expected structure
+    var1 = relay.var("var1", shape=[1, 20], dtype="float16")
+    var2 = relay.var("var2", shape=[1, 20], dtype="float16")
+    data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
+    weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
+    r1 = var1 + var1
+    r2 = var2 + var2
+    let2 = relay.Let(
+        var2,
+        relay.cast(relay.nn.dense(r1, weight, units=20, out_dtype="float32"), "float16"),
+        r2,
+    )
+    let1 = relay.Let(
+        var1,
+        relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16"),
+        let2,
+    )
+    expected_mod = tvm.IRModule.from_expr(let1)
+    expected_mod = InferType()(expected_mod)
+
+    assert tvm.ir.structural_equal(expected_mod, output_mod)
+
+
+def test_where_simple():
+    data = relay.var("data", shape=[1, 20])
+    weight = relay.var("weight", shape=[20, 20])
+    a = relay.nn.dense(data, weight, units=20)
+    b = relay.where(data, a, a)
+    mod = tvm.IRModule.from_expr(b)
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 20]).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=[20, 20]).astype("float32"),
+    }
+
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+
+    # Create expected module
+    data = relay.cast(relay.var("data", shape=[1, 20]), "float16")
+    weight = relay.cast(relay.var("weight", shape=[20, 20]), "float16")
+    a = relay.cast(relay.nn.dense(data, weight, units=20, out_dtype="float32"), "float16")
+    b = relay.where(data, a, a)
+    expected_mod = tvm.IRModule.from_expr(b)
+    expected_mod = InferType()(expected_mod)
+
+    assert tvm.ir.structural_equal(expected_mod, output_mod)
+
+
+def test_batch_matmul_simple():
+    """Batch matmul is a special case where we try to accumulate to fp16.
+
+    This is because heterogeneous accumulation dtypes do not work
+    on all platforms at the moment.
+    """
+    data = relay.var("data", shape=[1, 1, 20])
+    weight = relay.var("weight", shape=[1, 20, 20])
+    a = relay.nn.batch_matmul(data, weight)
+    mod = tvm.IRModule.from_expr(a)
+    mod_params = {
+        "data": np.random.uniform(-1, 1, size=[1, 1, 20]).astype("float32"),
+        "weight": np.random.uniform(-1, 1, size=[1, 20, 20]).astype("float32"),
+    }
+    output_mod = verify_fp32_fp16_output_close(mod, mod_params, atol=0.01, rtol=0.01)
+    # Create expected module
+    data = relay.cast(relay.var("data", shape=[1, 1, 20]), "float16")
+    weight = relay.cast(relay.var("weight", shape=[1, 20, 20]), "float16")
+    a = relay.nn.batch_matmul(data, weight, out_dtype="float16")
+    expected_mod = tvm.IRModule.from_expr(a)
+    expected_mod = InferType()(expected_mod)
+    assert tvm.ir.structural_equal(expected_mod, output_mod)

Review comment:
       Add the following to the end of the file (it also needs `import pytest`, which this module does not import yet).
   
   ```python
   if __name__ == "__main__":
       pytest.main([__file__])
   ```

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // We can cast Vars and Constants to the right types so don't care about the types.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;
+    } else {
+      final_color = initial_color;
+    }
+
+    // Create the new arguments to the call.
+    DataType wanted_arg_dtypes = final_color == GREEN ? DataType::Float(16) : DataType::Float(32);
+    auto call_args_and_types = CastAllArgs(new_args, new_arg_types, wanted_arg_dtypes);
+
+    Array<Expr> call_args = call_args_and_types.first;
+    Array<Type> call_arg_types;
+
+    if (call_node->op.as<FunctionNode>()) {
+      // Function Nodes don't store type info in the Call, it should be a []
+      call_arg_types = call_node->type_args;
+    } else {
+      call_arg_types = call_args_and_types.second;
+    }
+
+    // Finally create the new attributes.
+    if (final_color == GREEN) {
+      FP16OpDType output_dtypes = output_dtype_func(call_node);
+
+      Attrs new_attrs = GetNewAttrs(call_node, output_dtypes.accumulation_dtype);
+      Expr output = Call(new_op, call_args, new_attrs, call_arg_types, call_node->span);
+      if (output_dtypes.accumulation_dtype != output_dtypes.output_dtype) {
+        output = CastArg(output, GetType(output), output_dtypes.output_dtype);
+      }
+      return output;
+    } else {
+      return Call(new_op, call_args, call_node->attrs, call_arg_types, call_node->span);
+    }

Review comment:
       ```suggestion
       }
       return Call(new_op, call_args, call_node->attrs, call_arg_types, call_node->span);
   ```
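The suggestion drops the `else` because the GREEN branch already returns, which leaves the non-GREEN path as a plain fall-through. The sketch below reproduces that control flow in Python. `Color`, `final_color`, and `plan_call` are hypothetical names standing in for `FP16ConversionCategory` and the body of `VisitExpr_` above. The sketch also shows the coloring rule itself. A GRAY op becomes GREEN only when every argument is fp16-compatible. A GREEN op gets a trailing cast only when its accumulation and output dtypes differ. The dtype strings are illustrative labels, not TVM calls.

```python
from enum import Enum
from typing import List, Tuple


class Color(Enum):
    """Hypothetical mirror of FP16ConversionCategory."""

    RED = 0
    GRAY = 1
    GREEN = 2


def final_color(initial: Color, args_fp16_compatible: List[bool]) -> Color:
    """GRAY ops follow their inputs; GREEN and RED ops keep their initial color."""
    if initial is Color.GRAY:
        # all([]) is True, matching the C++ flag that starts out true.
        return Color.GREEN if all(args_fp16_compatible) else Color.RED
    return initial


def plan_call(
    initial: Color,
    args_fp16_compatible: List[bool],
    accumulation_dtype: str,
    output_dtype: str,
) -> Tuple[str, List[str]]:
    """Return (argument dtype, rewrite steps) for one call, using early returns."""
    if final_color(initial, args_fp16_compatible) is Color.GREEN:
        steps = ["cast args to float16", f"set out_dtype={accumulation_dtype}"]
        if accumulation_dtype != output_dtype:
            steps.append(f"cast result to {output_dtype}")
        return "float16", steps
    # No `else` needed: the GREEN branch has already returned.
    return "float32", ["cast args to float32", "keep original attrs"]


# A conv-like GREEN op that accumulates in fp32 but should output fp16.
print(plan_call(Color.GREEN, [True, True], "float32", "float16"))
# ('float16', ['cast args to float16', 'set out_dtype=float32', 'cast result to float16'])
```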

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,337 @@
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify the out_dtype field of relevant attributes.
+     This field represents the accumulation dtype for some operations, e.g.
+     conv2d might take in fp16 and produce an fp32 result.
+     attrs is received as const, so it is const_cast before being mutated.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+
+    DataType cur_type = (mutable_attrs->out_dtype);
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify the dtype field of relevant attributes.
+     This determines the output dtype for some ops. For example,
+     zeros creates a tensor of zeros of the specified dtype.
+     attrs is received as const, so it is const_cast before being mutated.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    DataType cur_type = (mutable_attrs->dtype);
+
+    if (cur_type.is_float() || cur_type.is_void()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (ignore_non_float && !(tensor_type->dtype).is_float()) ||
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type field_type : tuple_type->fields) {
+        if (!IsFP16Type(field_type, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << "; don't know how to handle it";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Also cache the reverse direction: casting the result back to expr_dtype returns the original node.
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << "; don't know how to cast it for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}
+
+  Expr VisitExpr_(const CallNode* call_node) {
+    FP16ConversionCategory initial_color = colorer(call_node);
+    auto new_op = this->Mutate(call_node->op);
+
+    // Mutate arguments to FP16 form first if possible and keep track of whether all floating point
+    // tensors are in FP16 form already. This is useful for propagating color.
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    bool all_args_fp16_compatible = true;
+    for (Expr arg : call_node->args) {
+      Expr new_arg = this->Mutate(arg);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+
+      if (all_args_fp16_compatible) {
+        // Vars and Constants can be cast to the right type, so their current types don't matter.
+        bool is_fp16_compatible = IsFP16Type(new_arg_type, true) || arg->IsInstance<VarNode>() ||
+                                  arg->IsInstance<ConstantNode>();
+        all_args_fp16_compatible &= is_fp16_compatible;
+      }
+    }
+
+    // Determine the final color.
+    FP16ConversionCategory final_color;
+    if (initial_color == GRAY) {
+      final_color = all_args_fp16_compatible ? GREEN : RED;

Review comment:
       Can you provide an example of an FP16-incompatible argument?
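       For reference, the GRAY resolution rule above can be sketched in isolation. This is a minimal standalone sketch, not TVM code: `ArgKind`, `Arg`, `IsFP16OrNonFloat` and `ResolveGray` are hypothetical stand-ins, and dtypes are modeled as plain strings. Under these assumptions, an FP16-incompatible argument is a floating-point, non-fp16 tensor produced by a call rather than by a Var or Constant, for example the fp32 output of a RED op such as `nn.softmax` feeding a GRAY `add`. Non-float tensors are treated as compatible, which is what the `ignore_non_float` flag is documented to do.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for the pass's types (not TVM's API).
enum Color { RED, GRAY, GREEN };
enum class ArgKind { Var, Constant, Call };

struct Arg {
  ArgKind kind;
  std::string dtype;  // e.g. "float32", "float16", "int32"
};

// True if the tensor is fp16 or not a floating-point type at all.
bool IsFP16OrNonFloat(const std::string& dtype) {
  bool is_float = dtype.rfind("float", 0) == 0;  // dtype starts with "float"
  return !is_float || dtype == "float16";
}

// A GRAY op becomes GREEN only if every argument is fp16-compatible.
// Vars and Constants always count as compatible because they can be cast.
Color ResolveGray(const std::vector<Arg>& args) {
  for (const Arg& a : args) {
    bool compatible = IsFP16OrNonFloat(a.dtype) || a.kind == ArgKind::Var ||
                      a.kind == ArgKind::Constant;
    if (!compatible) return RED;
  }
  return GREEN;
}
```

       For instance, `add(conv2d_fp16_out, bias_var_fp32)` resolves to GREEN because the fp32 Var can be cast, while `add(softmax_fp32_out, x_fp16)` resolves to RED because the fp32 call output is the incompatible argument.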

##########
File path: src/relay/transforms/fp32_to_fp16.h
##########
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fp32_to_fp16.h
+ * \brief Utilities and common types used for FP32->FP16 pass.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_
+
+#include <tvm/ir/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+struct FP16OpDType {
+  DataType accumulation_dtype;
+  DataType output_dtype;
+};
+
+// GREEN colored ops should always be done in FP16 due to the speed and memory savings
+// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast.
+// RED colored ops should not be done in FP16 due to numerical reasons.
+enum FP16ConversionCategory { RED, GRAY, GREEN };
+
+using OpStringSet = std::unordered_set<std::string>;
+
+// Default lists inspired by TF's classifications:
+// github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+// They are biased toward NVIDIA Tensor Cores, so adjust the lists for your target hardware.
+OpStringSet DEFAULT_GREEN_LIST({
+    "nn.conv1d",
+    "nn.conv2d",
+    "nn.conv3d",
+    "nn.conv1d_transpose",
+    "nn.conv2d_transpose",
+    "nn.conv3d_transpose",
+    "nn.dense",
+    "nn.batch_matmul",
+});
+OpStringSet DEFAULT_GRAY_LIST({
+    // These ops add new data or change shape
+    "nn.pad",
+    "nn.batch_flatten",
+    "concatenate",
+    "zeros",
+    "split",
+    "squeeze",
+    "transpose",
+    "expand_dims",
+    "reshape",
+    "dyn.reshape",
+    "broadcast_to_like",
+    "dyn.broadcast_to",
+    "strided_slice",
+    "dyn.strided_slice",
+    "take",
+    "argwhere",
+    "where",
+    "tile",
+    "dyn.tile",
+    "scatter",
+    "full",
+    "dyn.full",
+    // Comparison
+    "less",
+    "greater",
+    "less_equal",
+    "greater_equal",
+    // By definition copy and cast will become green or red based on inputs
+    "copy",
+    "cast",
+    "cast_like",
+    // Simple arithmetic
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "nn.bias_add",
+    "nn.batch_norm",
+    "sum",
+    "mean",
+    "sqrt",
+    "shape_of",
+    // Simple activations
+    "max",
+    "min",
+    "maximum",
+    "minimum",
+    "nn.relu",
+    "nn.leaky_relu",
+    "nn.prelu",
+    "nn.dropout",
+    // Complicated activations which saturate in a narrow range
+    "sigmoid",
+    "tanh",
+    // Pooling operations
+    "nn.max_pool1d",
+    "nn.max_pool2d",
+    "nn.max_pool3d",
+    "nn.avg_pool1d",
+    "nn.avg_pool2d",
+    "nn.avg_pool3d",
+    // "nn.global_max_pool1d", // does not exist yet
+    "nn.global_max_pool2d",
+    // "nn.global_max_pool3d", // does not exist yet
+    // "nn.global_avg_pool1d", // does not exist yet
+    "nn.global_avg_pool2d",
+    // "nn.global_avg_pool3d", // does not exist yet
+    "nn.adaptive_max_pool1d",
+    "nn.adaptive_max_pool2d",
+    "nn.adaptive_max_pool3d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+});
+OpStringSet DEFAULT_RED_LIST({
+    // In general if |f(x)| >> |x| for expected inputs then put the op here.
+    "exp",
+    "power",
+    "nn.cross_entropy",
+    "nn.cross_entropy_with_logits",
+    "nn.softmax",
+    "nn.l2_normalize",
+    // The error function does not appear to lower to an fp16 version in LLVM yet.
+    // Move it to the gray list once it does.
+    "erf",
+});
+
+class DefaultFP16Colorer {
+  /* The default class to initially color ops for conversion using lists.
+
+  Creates a callable which given a CallNode* returns the node's color.
+  */
+ private:
+  std::unordered_map<std::string, FP16ConversionCategory> op_to_initial_color;
+
+ public:
+  DefaultFP16Colorer(OpStringSet red_list = DEFAULT_RED_LIST,
+                     OpStringSet gray_list = DEFAULT_GRAY_LIST,
+                     OpStringSet green_list = DEFAULT_GREEN_LIST) {
+    std::vector<std::pair<OpStringSet, FP16ConversionCategory>> lists_and_colors{
+        {red_list, RED}, {gray_list, GRAY}, {green_list, GREEN}};
+
+    for (auto list_and_color : lists_and_colors) {
+      OpStringSet ops = list_and_color.first;
+      FP16ConversionCategory color = list_and_color.second;
+      for (std::string op_name : ops) {
+        op_to_initial_color.insert({{op_name, color}});
+      }
+    }
+  }
+
+  FP16ConversionCategory operator()(const CallNode* call, bool ignore_missing = true) {
+    if (auto* op_node = (call->op).as<tvm::OpNode>()) {
+      std::string op_name = op_node->name;
+      auto color = op_to_initial_color.find(op_name);
+
+      if (color == op_to_initial_color.end()) {
+        if (ignore_missing) {
+          LOG(WARNING) << "Op name " << op_name << " not included in fp16 conversion lists!";
+          return RED;
+        } else {
+          LOG(FATAL) << "Op name " << op_name << " not included in fp16 conversion lists!";
+        }
+      }
+
+      return color->second;
+    } else if ((call->op).as<FunctionNode>()) {
+      // Make RED to avoid messing with function headers.
+      return RED;
+    } else {
+      LOG(FATAL) << "FP16 conversion only supports call nodes with OpNodes or Functions, got "
+                 << call->op;
+      return RED;
+    }
+  }
+};
+
+class DefaultFP16OpDefinition {
+  /* The default callable for determining accumulation_dtypes for ops. */
+ public:
+  FP16OpDType operator()(const CallNode* call) {
+    // TODO(AndrewZhaoLuo): remove when batch_matmul handles accumulation dtypes well.
+    // Batched matmul has inconsistent support for mixed precision operations.
+    // Many schedules ignore the out_dtype attribute which leads to errors when
+    // input types do not match the out_dtype. Therefore, accumulate to fp16 if green.
+    if (auto op_node = call->op.as<OpNode>()) {
+      if (op_node->name == "nn.batch_matmul") {
+        return {DataType::Float(16), DataType::Float(16)};
+      }
+    }

Review comment:
       This again illustrates the importance of registering a casting function for each op instead of special-casing ops here.
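       As a rough illustration of that suggestion, here is a standalone sketch in which each op registers its own dtype rule and the pass only performs a lookup. `OpDType`, `OpDTypeRegistry`, `Register` and `Lookup` are hypothetical names, dtypes are modeled as strings, rules take no arguments (a real rule would likely receive the call node), and the fallback of fp32 accumulation with fp16 output is an assumption for illustration, not the PR's default. Inside TVM, such rules would more naturally hang off op attribute registration (`set_attr` / `Op::GetAttrMap`) than off a standalone map.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical, simplified stand-in for FP16OpDType with string dtypes.
struct OpDType {
  std::string accumulation_dtype;
  std::string output_dtype;
};

using OpDTypeFunc = std::function<OpDType()>;

class OpDTypeRegistry {
 public:
  // Each op registers its own rule, so the pass needs no per-op special cases.
  void Register(const std::string& op_name, OpDTypeFunc func) {
    rules_[op_name] = std::move(func);
  }

  // Assumed fallback for ops without a rule: accumulate in fp32, emit fp16.
  OpDType Lookup(const std::string& op_name) const {
    auto it = rules_.find(op_name);
    if (it == rules_.end()) return {"float32", "float16"};
    return it->second();
  }

 private:
  std::unordered_map<std::string, OpDTypeFunc> rules_;
};
```

       With this shape, the `nn.batch_matmul` workaround becomes a single `Register` call next to that op's definition, and removing it once the batch_matmul schedules honor `out_dtype` no longer touches the pass.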




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org