You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by "tkonolige (via GitHub)" <gi...@apache.org> on 2023/03/24 15:55:22 UTC

[GitHub] [tvm] tkonolige commented on a diff in pull request #14382: [Rewrite] Fix rewriting division to constant to handle vector case

tkonolige commented on code in PR #14382:
URL: https://github.com/apache/tvm/pull/14382#discussion_r1147776041


##########
src/relay/transforms/div_to_mul.cc:
##########
@@ -26,42 +26,61 @@
 namespace tvm {
 namespace relay {
 
+template <typename T>
+inline bool const_has_values(size_t size, const ConstantNode* const_node,
+                             const std::vector<T>&& values) {
+  for (size_t i = 0; i < size; i++) {
+    T data = static_cast<T*>(const_node->data->data)[i];
+    for (const T& v : values) {
+      if (data == v) return true;
+    }
+  }
+  return false;
+}
+
+inline size_t get_num_elements_const(const ConstantNode* const_node) {
+  const auto& shape = const_node->data.Shape();
+
+  size_t cnt_elements = 1;
+  for (const auto& dim : shape) {
+    cnt_elements *= dim;
+  }
+
+  return cnt_elements;
+}
+
 class DivToMulRewrite : public MixedModeMutator {
   Expr Rewrite_(const CallNode* pre, const Expr& post) final {
     if (const CallNode* call_node = post.as<CallNode>()) {
       if (call_node->op == Op::Get("divide")) {
         auto rhs = call_node->args[1].as<ConstantNode>();
         if (rhs != nullptr) {
-          auto inv =
-              runtime::NDArray::Empty(rhs->data.Shape(), rhs->data.DataType(), rhs->data->device);
+          auto one = runtime::NDArray::Empty({}, rhs->data.DataType(), rhs->data->device);
+          size_t num_ele = get_num_elements_const(rhs);
           std::string dtype = DLDataType2String(rhs->data.DataType());
+
+          bool const_has_zero_flag = false;
           if (dtype == "float32") {
-            float rhs_val = static_cast<float*>(rhs->data->data)[0];
-            // Check for division by zero
-            if (rhs_val == 0.) {
-              return post;
-            }
-            static_cast<float*>(inv->data)[0] = 1. / rhs_val;
+            static_cast<float*>(one->data)[0] = 1.;
+            const_has_zero_flag = const_has_values<float>(num_ele, rhs, {0.});
           } else if (dtype == "float64") {
-            double rhs_val = static_cast<double*>(rhs->data->data)[0];
-            // Check for division by zero
-            if (rhs_val == 0.) {
-              return post;
-            }
-            static_cast<double*>(inv->data)[0] = 1. / rhs_val;
+            static_cast<double*>(one->data)[0] = 1.;
+            const_has_zero_flag = const_has_values<double>(num_ele, rhs, {0.});
           } else if (dtype == "float16") {
-            // Do f16 math in f32
-            float rhs_val = __gnu_h2f_ieee(static_cast<uint16_t*>(rhs->data->data)[0]);
-            // Check for division by zero
-            if (rhs_val == 0.) {
-              return post;
-            }
-            static_cast<uint16_t*>(inv->data)[0] = __gnu_f2h_ieee(1. / rhs_val);
+            static_cast<uint16_t*>(one->data)[0] = __gnu_f2h_ieee(1.);
+            // have to handle both + and - zero semantics manually here
+            const_has_zero_flag = const_has_values<uint16_t>(num_ele, rhs, {0x0000, 0x8000});
           } else {
-            // Cannot do 1/int because it will truncate
+            LOG(WARNING) << "Unknown dtype not handled for div_to_mull: " << rhs->data.DataType();
             return post;
           }
-          return Multiply(call_node->args[0], Constant(inv));
+
+          if (const_has_zero_flag) {
+            return post;
+          }
+
+          // rely on constant folding to fold things

Review Comment:
   Are you sure constant folding is going to fire at the right time? I can't remember whether the divisions need to be rewritten before FakeQuantization or whether the timing doesn't matter.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org