You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/04 20:01:46 UTC

[GitHub] [tvm] mbs-octoml commented on a change in pull request #9159: Address review comments on Arm(R) Ethos(TM)-U PR 3/6

mbs-octoml commented on a change in pull request #9159:
URL: https://github.com/apache/tvm/pull/9159#discussion_r721674546



##########
File path: python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
##########
@@ -447,81 +455,75 @@ def _create_npu_feature_map(serial_feature_map):
     )
     nfm.layout = layout_map[layout]
     nfm.strides = vapi.NpuShape3D(
-        int(serial_feature_map.stride_h.value),
-        int(serial_feature_map.stride_w.value),
-        int(serial_feature_map.stride_c.value),
+        int(serial_feature_map.stride_h),
+        int(serial_feature_map.stride_w),
+        int(serial_feature_map.stride_c),
     )
     return nfm
 
 
-def _create_npu_kernel(serial_kernel):
+def _create_npu_kernel(serial_kernel: spec.SerialKernel) -> vapi.NpuKernel:
     """This is a helper function to capture a list
-    of arguments to create Vela NpuKernel object
+    of arguments to create Vela NpuKernel object.
     """
     nknl = vapi.NpuKernel(
-        w=int(serial_kernel.width.value),
-        h=int(serial_kernel.height.value),
-        stride_x=int(serial_kernel.stride_w.value),
-        stride_y=int(serial_kernel.stride_h.value),
-        dilation_x=int(serial_kernel.dilation_w.value),
-        dilation_y=int(serial_kernel.dilation_h.value),
+        w=int(serial_kernel.width),
+        h=int(serial_kernel.height),
+        stride_x=int(serial_kernel.stride_w),
+        stride_y=int(serial_kernel.stride_h),
+        dilation_x=int(serial_kernel.dilation_w),
+        dilation_y=int(serial_kernel.dilation_h),
     )
     return nknl
 
 
-def _create_npu_address_range(serial_address_range):
+def _create_npu_address_range(
+    serial_address_range: spec.SerialAddressRange,
+) -> vapi.NpuAddressRange:
     """This is a helper function to capture a list
-    of arguments to create Vela NpuAddressRange object
+    of arguments to create Vela NpuAddressRange object.
     """
     addr_range = vapi.NpuAddressRange(
         # region will be updated later
         region=0,
         address=serial_address_range.address,
-        length=int(serial_address_range.length.value),
+        length=int(serial_address_range.length),
     )
     return addr_range
 
 
 def _create_npu_quantization(
-    scale,
-    zero_point,
-):
+    scale: Union[tvm.tir.FloatImm, float],
+    zero_point: Union[tvm.tir.IntImm, int],
+) -> vapi.NpuQuantization:
     """This is a helper function to capture a list
-    of arguments to create Vela NpuQuantization object
+    of arguments to create Vela NpuQuantization object.
     """
-    # Scale could be an ndarray if per-channel quantization is available
-    if not isinstance(scale, tvm.tir.expr.Load):
-        if isinstance(scale.value, float):
-            scale = np.single(scale.value)
-        else:
-            assert isinstance(scale.value.value, float)
-            scale = np.single(scale.value.value)
-    q_params = vapi.NpuQuantization(scale_f32=scale, zero_point=zero_point.value)
-    return q_params
+    return vapi.NpuQuantization(scale_f32=float(scale), zero_point=int(zero_point))

Review comment:
       This looks like the only semantic change in this PR, right? Could you explain it in the PR summary? Thanks.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org