You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/13 23:45:27 UTC

[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #5980: Fixed point multiplication improvements for AArch64

anijain2305 commented on a change in pull request #5980:
URL: https://github.com/apache/incubator-tvm/pull/5980#discussion_r453998267



##########
File path: include/tvm/relay/attrs/transform.h
##########
@@ -298,6 +298,17 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
   }
 };
 
+/*! \brief Attributes for FixedPointMultiply operator */
+struct FixedPointMultiplyAttrs : public tvm::AttrsNode<FixedPointMultiplyAttrs> {
+  int32_t multiplier;
+  int32_t shift;
+
+  TVM_DECLARE_ATTRS(FixedPointMultiplyAttrs, "relay.attrs.FixedPointMultiplyAttrs") {
+    TVM_ATTR_FIELD(multiplier).describe("Integer multiplier.");

Review comment:
       Nit, but let's remove the period at the end to be consistent with the others.
   It might also be good to briefly describe the multiplier and shift.

##########
File path: include/tvm/tir/op.h
##########
@@ -552,6 +552,24 @@ TVM_DLL PrimExpr trunc(PrimExpr x);
  */
 TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
 
+/*!
+ * \brief Execute a multiplication between two Q-numbers x and y
+ * followed by a right shift s. The mathematical expression is:
+ *
+ *    out = round(x*y*2^-s)
+ *
+ * More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
+ *
+ * The rounding rule is to the nearest value, rounding half up
+ * (i.e., round(x.1) = x and round (x.5) = x+1)
+ * \param x first Q-number
+ * \param y second Q-number
+ * \param q Q-ness of x and y

Review comment:
       Agreed, number of fractional bits is better description

##########
File path: src/relay/op/tensor/unary.cc
##########
@@ -274,6 +274,20 @@ TVM_REGISTER_GLOBAL("relay.op._make.clip").set_body_typed([](Expr a, double a_mi
   return Call(op, {a}, Attrs(attrs), {});
 });
 
+// relay.fixed_point_multiply
+TVM_REGISTER_NODE_TYPE(FixedPointMultiplyAttrs);
+
+RELAY_REGISTER_OP("fixed_point_multiply")
+    .describe(R"code( fixed point multiplication )code" TVM_ADD_FILELINE)
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_type_rel("Identity", IdentityRel)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+    .set_attrs_type<FixedPointMultiplyAttrs>()
+    .set_support_level(3);

Review comment:
       I think level 10 is better here
   
   @tqchen any suggestions here?

##########
File path: topi/python/topi/arm_cpu/injective.py
##########
@@ -62,9 +62,13 @@ def schedule_injective(outs):
     outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
     s = te.create_schedule([x.op for x in outs])
     x = outs[0]
+    ins = x.op.input_tensors
+    dtype = ins[0].dtype if len(ins) > 0 else x.dtype
+    max_vlen = 4 if dtype == 'int32' else 8

Review comment:
       Seems like 4 should be better for float32 as well. If it is, then maybe we should always use 4 instead of 8.

##########
File path: python/tvm/tir/op.py
##########
@@ -965,6 +965,34 @@ def popcount(x):
     """
     return call_intrin(x.dtype, "tir.popcount", x)
 
+def qmuls(x, y, q, s):
+    """Execute a multiplication between two Q-numbers x and y
+    followed by a right shift s. The mathematical expression is:
+
+       out = round(x*y*2^-s)

Review comment:
       Maybe we should add a line to explain why there is a multiplication factor of 2 (perhaps rounding)

##########
File path: include/tvm/tir/op.h
##########
@@ -552,6 +552,24 @@ TVM_DLL PrimExpr trunc(PrimExpr x);
  */
 TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
 
+/*!
+ * \brief Execute a multiplication between two Q-numbers x and y
+ * followed by a right shift s. The mathematical expression is:
+ *
+ *    out = round(x*y*2^-s)
+ *
+ * More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
+ *
+ * The rounding rule is to the nearest value, rounding half up
+ * (i.e., round(x.1) = x and round (x.5) = x+1)
+ * \param x first Q-number
+ * \param y second Q-number
+ * \param q Q-ness of x and y

Review comment:
       Is the number of fractional bits the same for x and y, and is that why we need only one input? Let's make the description clearer.
   IIUC, one could imagine using this op to perform something like Q1.31 * Q2.30. But I think this op is more restrictive than that. If it is, then let's mention it.

##########
File path: include/tvm/tir/builtin.h
##########
@@ -92,6 +92,14 @@ TVM_DLL const Op& shift_right();
  */
 TVM_DLL const Op& large_uint_imm();
 
+/*!
+ * \brief Execute a multiplication between two Q-numbers x and y
+ * followed by a right shift s
+ * The default rounding rule is to the nearest value, rounding half up
+ * (i.e., round(x.1) = x and round (x.5) = x+1)
+ */
+TVM_DLL const Op& qmuls();

Review comment:
       We should come up with a better name. Currently, `qmuls` seems vague.
   It is not clear what `q` and `s` stand for to a person not familiar with Q numbers.
   
   Why not use the same `fixed_point_multiply`?

##########
File path: src/target/intrin_rule.cc
##########
@@ -115,6 +115,51 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf")
       *rv = isinf(call->args[0]);
     });
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.qmuls")
+    .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+      using tir::make_const;
+
+      PrimExpr e = args[0];
+      const tir::CallNode* call = e.as<tir::CallNode>();
+      CHECK(call != nullptr);
+
+      PrimExpr x = call->args[0];
+      PrimExpr y = call->args[1];
+      PrimExpr q = call->args[2];
+      PrimExpr s = call->args[3];
+
+      // Only int32 types are supported (any number of lanes is allowed)
+      CHECK(x.dtype().code() == DLDataTypeCode::kDLInt && x.dtype().bits() == 32);
+      CHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
+      CHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32);
+
+      DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
+      DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
+
+      // 1) Calculating the integer multiplier and integer shift
+      PrimExpr zero = make_const(s.dtype(), 0);
+      PrimExpr left_shift = tir::Select((s > zero), s, zero);
+      PrimExpr right_shift = tir::Select(s > zero, zero, -s);
+
+      // 2) Multiply the integer multiplier
+      x = tir::Select(left_shift != zero, x << cast(hp_dtype, left_shift), cast(hp_dtype, x));
+
+      // 3) Perform the multiplication in higher precision.
+      x = x * y;

Review comment:
       Do we need to cast y to `hp_dtype`?

##########
File path: src/relay/op/tensor/unary.cc
##########
@@ -274,6 +274,20 @@ TVM_REGISTER_GLOBAL("relay.op._make.clip").set_body_typed([](Expr a, double a_mi
   return Call(op, {a}, Attrs(attrs), {});
 });
 
+// relay.fixed_point_multiply
+TVM_REGISTER_NODE_TYPE(FixedPointMultiplyAttrs);
+
+RELAY_REGISTER_OP("fixed_point_multiply")
+    .describe(R"code( fixed point multiplication )code" TVM_ADD_FILELINE)

Review comment:
       No need of space between `(` and `fixed`. Similarly at the end.

##########
File path: topi/python/topi/arm_cpu/tensor_intrin.py
##########
@@ -451,3 +451,55 @@ def _instr(index):
     return te.decl_tensor_intrin(
         C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer},
         default_buffer_params=buffer_params)
+
+def _qmuls_arm(op):
+    """
+    Implementation of qmuls through arm intrinsics sqrdmulh and srshl
+    when q == 31.
+
+    Please note that this is introducing a small round-up error for
+    some corner cases. This is because we are rounding twice instead
+    than only once. I.e.:
+
+        * original qmuls: round(x*y*2^-s)
+        * arm qmuls: round(round(x*y)*2^-s)
+    """
+    x = op.args[0]
+    y = op.args[1]
+    q = op.args[2]
+    s = op.args[3]
+
+    # Don't use this intrinsic if we don't have a int32x4 vector
+    # and if we are not multiplying q31 numbers
+    if x.dtype != "int32x4" and q == 31:

Review comment:
       Can you please double-check the condition? Should there be an `or` here?

##########
File path: topi/python/topi/math.py
##########
@@ -612,6 +612,31 @@ def _compute(*indices):
         return tvm.te.max(tvm.te.min(value, const_max), const_min)
     return te.compute(x.shape, _compute)
 
+@tvm.te.tag_scope(tag=tag.ELEMWISE)
+def fixed_point_multiply(x, multiplier, shift):
+    """
+
+    Parameters
+    ----------
+    x :          tvm.te.Tensor or Expr

Review comment:
       Lets be consistent with the doc format. You can take a look at the example just above this.

##########
File path: src/relay/qnn/op/requantize.cc
##########
@@ -153,9 +153,19 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
         static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
     // Skip if input and output scales are same.
     if (!IsEqualScalar(input_scale, output_scale)) {
-      scaled_int32_t =
-          FixedPointMultiply(scaled_int32_t, double_multiplier, input_shape, param->rounding);
+      int32_t fixed_point_multiplier, shift;
+      std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
+
+      const bool is_upward_rounding = (param->rounding == "UPWARD");
+
+      // When using upward rounding (i.e., x.5 rounded to x+1), leverage
+      // the fixed_point_muliply intrinsic
+      scaled_int32_t = (is_upward_rounding ? relay::FixedPointMultiply(
+                                                 scaled_int32_t, fixed_point_multiplier, shift)
+                                           : FixedPointMultiply(scaled_int32_t, double_multiplier,

Review comment:
       We should remove the UPWARD rounding code from the non-Relay util FixedPointMultiply function, as that code will no longer be executed.
   
   Also, it is now confusing why we have two functions with the same name :) We should rename the already-existing util function.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org