You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/06/08 19:53:45 UTC

[GitHub] [tvm] mbrookhart commented on a change in pull request #8069: [Relay] [Pass] Add FP16 model conversion pass

mbrookhart commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r647748964



##########
File path: tests/python/frontend/mxnet/test_forward.py
##########
@@ -1223,6 +1221,8 @@ def verify(shape, axis=1, fix_gamma=False):
 
 @tvm.testing.uses_gpu
 def test_forward_instance_norm():
+    np.random.seed(90)
+

Review comment:
       Tianqi prefers that we don't set random seeds, so that intermittent bugs can surface across CI runs instead of being hidden by a fixed seed.

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    return h1 ^ (h2 << 1);

Review comment:
       If I remember correctly, XOR hash combine is pretty prone to hash collisions. Maybe use the boost approach? `return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2))`

##########
File path: src/relay/transforms/fp32_to_fp16.cc
##########
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file fp32_to_fp16.cc
+ * \brief Rewrite a graph into an fp16 form.
+ */
+#include "fp32_to_fp16.h"
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    return h1 ^ (h2 << 1);
+  }
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// A function which maps CallNodes to their initial conversion color
+using ColorFunc = std::function<FP16ConversionCategory(const CallNode*)>;
+
+// A function which maps green CallNodes to wanted accumulation and output dtypes
+using OutputDtypeFunc = std::function<FP16OpDType(const CallNode*)>;
+
+class AmpGraphCreator : public ExprMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+  const ColorFunc colorer;
+  const OutputDtypeFunc output_dtype_func;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs new_attrs = Attrs(call->attrs);
+    if (new_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = new_attrs.as<Conv1DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<DenseAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) {
+        ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = new_attrs.as<InitOpAttrs>()) {
+        ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return new_attrs;
+  }
+
+  template <typename T>
+  void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+     */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    if ((mutable_attrs->out_dtype).is_float()) mutable_attrs->out_dtype = accumulation_dtype;
+  }
+
+  template <typename T>
+  void ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    T* mutable_attrs = const_cast<T*>(attrs);
+    if ((mutable_attrs->dtype).is_float()) mutable_attrs->dtype = accumulation_dtype;
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }
+
+  bool IsFP16Type(const Type& t, bool ignore_non_float = false) const {
+    /* Returns whether t is a type with only fp16 elements.
+       If ignore_non_float, then ignore non-floating types.
+     */
+    if (const TensorTypeNode* tensor_type = t.as<TensorTypeNode>()) {
+      return (!ignore_non_float || (tensor_type->dtype).is_float()) &&
+             tensor_type->dtype == DataType::Float(16);
+    } else if (const TupleTypeNode* tuple_type = t.as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        if (!IsFP16Type(t, ignore_non_float)) return false;
+      }
+      return true;
+    } else {
+      LOG(FATAL) << "Unsupported type " << t << " we don't know how to handle";
+      return false;
+    }
+  }
+
+  Expr CachedCast(const Expr& expr, const DataType& expr_dtype, const DataType& wanted_dtype) {
+    /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */
+
+    // If this is not a floating point type, do not cast. E.g. it might be an integer
+    if (!expr_dtype.is_float()) {
+      return expr;
+    }
+
+    const ExprNode* expr_node = expr.as<ExprNode>();
+    if (!expr_node) {
+      LOG(FATAL) << "Non-expression node found in cast: " << expr;
+    }
+
+    // Use cached result if possible.
+    auto search = cast_nodes_cache.find({expr_node, wanted_dtype});
+    if (search != cast_nodes_cache.end()) {
+      return search->second;
+    }
+
+    Expr result = expr_dtype == wanted_dtype ? expr : Cast(expr, wanted_dtype);
+    cast_nodes_cache[{expr_node, wanted_dtype}] = result;
+
+    // Reverse the cache result, e.g. if we want to reverse the cast simply point to original node
+    const ExprNode* new_expr_node = result.as<ExprNode>();
+    cast_nodes_cache[{new_expr_node, expr_dtype}] = expr;
+    return result;
+  }
+
+  Expr CastArg(const Expr& expr, const Type& expr_type, const DataType& wanted_dtype) {
+    /* Helper for casting arguments to call_nodes handling all relevant cases. */
+    if (const TensorTypeNode* tensor_type = expr_type.as<TensorTypeNode>()) {
+      return CachedCast(expr, tensor_type->dtype, wanted_dtype);
+    } else if (const TupleTypeNode* tuple_type = expr_type.as<TupleTypeNode>()) {
+      Array<Expr> new_expr;
+      bool all_same = true;
+      for (size_t i = 0; i < (tuple_type->fields).size(); i++) {
+        Expr tuple_element = GetField(expr, i);
+        Type tuple_element_dtype = (tuple_type->fields)[i];
+        Expr casted_element = CastArg(tuple_element, tuple_element_dtype, wanted_dtype);
+        new_expr.push_back(casted_element);
+        all_same &= casted_element.same_as(tuple_element);
+      }
+      return all_same ? expr : Tuple(new_expr);
+    } else {
+      LOG(FATAL) << "Unsupported type " << expr_type << " we don't know how to cast for arguments!";
+      return expr;
+    }
+  }
+
+  std::pair<Array<Expr>, Array<Type>> CastAllArgs(const Array<Expr>& cur_args,
+                                                  const Array<Type>& cur_arg_types,
+                                                  const DataType& wanted_dtype) {
+    Array<Expr> new_args;
+    Array<Type> new_arg_types;
+    for (size_t i = 0; i < cur_args.size(); i++) {
+      Expr cur_arg = cur_args[i];
+      Type cur_arg_type = cur_arg_types[i];
+      Expr new_arg = CastArg(cur_arg, cur_arg_type, wanted_dtype);
+      Type new_arg_type = GetType(new_arg);
+      new_args.push_back(new_arg);
+      new_arg_types.push_back(new_arg_type);
+    }
+    return {new_args, new_arg_types};
+  }
+
+ public:
+  explicit AmpGraphCreator(ColorFunc colorer, OutputDtypeFunc output_dtype_func)
+      : ExprMutator(), colorer(colorer), output_dtype_func(output_dtype_func) {}

Review comment:
       Since you're using the recursive mutator here, you might run into stack overflows on larger models. I haven't looked at this pass in much detail yet; is it possible to do this with a post-order traversal (or the MixedModeMutator)?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org