You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/12/07 11:46:05 UTC

[tvm] branch main updated: [microNPU] Support different constant datatypes (#9626)

This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 224f8da  [microNPU] Support different constant datatypes (#9626)
224f8da is described below

commit 224f8dac7c27e81933ad8815c1a5ab61d4853403
Author: lhutton1 <35...@users.noreply.github.com>
AuthorDate: Tue Dec 7 11:45:42 2021 +0000

    [microNPU] Support different constant datatypes (#9626)
    
    Currently only the uint8 datatype is supported for constants, as this
    is all that was necessary until now. This PR allows different datatypes
    to be used for constants, including different datatypes within the
    same graph.
    
    A workaround was previously added for Mean legalization; this has
    also been removed, and the constant now uses its expected datatype
    (int16).
---
 .../tvm/relay/backend/contrib/ethosu/legalize.py   |  2 +-
 .../backend/contrib/ethosu/tir_to_cs_translator.py | 46 ++++++++--------
 src/relay/op/contrib/ethosu/binary_elementwise.cc  | 23 ++------
 tests/python/contrib/test_ethosu/test_codegen.py   | 10 ++--
 .../test_ethosu/test_tir_to_cs_translator.py       | 62 +++++++++++++++++++---
 5 files changed, 90 insertions(+), 53 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index 9b6c924..f8beb7f 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1040,7 +1040,7 @@ class MeanRewriter(DFPatternCallback):
             n = int(filter_height * filter_width)
             eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
 
-            scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="uint8"), dtype="uint8")
+            scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="int16"), dtype="int16")
 
             reduced_op = ethosu_ops.ethosu_binary_elementwise(
                 ifm=reduced_op,
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
index 638157f..464a3bd 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -110,12 +110,11 @@ def translate(tir_module, params):
     _npu_ops = list()
     for call_extern in call_extern_list:
         _npu_ops.append(translate_ethosu_tir_call_extern(call_extern))
-    _npu_ops, constant_tensor, scratch_size = assign_addresses(buffer_info, _npu_ops)
+    _npu_ops, constant_data, scratch_size = assign_addresses(buffer_info, _npu_ops)
     target_accel_config = vela_api.get_accelerator_config()
     cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_config)
     payload = vapi.npu_create_driver_payload(cmds, target_accel_config)
-    hex_value = "" if constant_tensor is None else constant_tensor.tobytes().hex()
-    return payload.hex(), hex_value, scratch_size
+    return payload.hex(), constant_data, scratch_size
 
 
 def extract_call_extern_list(mod):
@@ -277,27 +276,24 @@ def assign_addresses(buffer_info, npu_ops):
         raise ValueError(f"Unused IO : {buffer} in tir module.")
 
     scratch_size = 0
-    constant_tensor = None
+    constant_hex_data = []
+    total_constant_len = 0
     buffer_addresses = dict()
     for _buffer, info in buffer_info.items():
         if info.values is not None:
-            assert np.dtype(info.dtype) == np.uint8
             assert info.btype == BufferType.constant
             assert len(info.shape) == 1
-            if constant_tensor is None:
-                buffer_addresses[_buffer] = (0, info.btype)
-                assert info.values.dtype == np.uint8
-                size_in_bytes = info.values.size
-                # Every memory address the NPU access have to be 16 byte aligned
-                size_in_bytes = util.round_up(size_in_bytes, 16)
-                constant_tensor = np.resize(info.values, size_in_bytes)
-            else:
-                buffer_addresses[_buffer] = (constant_tensor.size, info.btype)
-                assert info.values.dtype == np.uint8
-                size_in_bytes = info.values.size
-                # Every memory address the NPU access have to be 16 byte aligned
-                size_in_bytes = util.round_up(size_in_bytes, 16)
-                constant_tensor = np.append(constant_tensor, np.resize(info.values, size_in_bytes))
+            buffer_addresses[_buffer] = (
+                (total_constant_len, info.btype) if constant_hex_data else (0, info.btype)
+            )
+            dtype_bytes = np.iinfo(np.dtype(info.dtype)).bits // 8
+            size_in_bytes = dtype_bytes * np.prod(list(info.shape))
+            # Every memory address the NPU access have to be 16 byte aligned
+            size_in_bytes = util.round_up(size_in_bytes, 16)
+            constant_tensor = np.resize(info.values, size_in_bytes // dtype_bytes)
+            constant_tensor = constant_tensor.tobytes().hex()
+            constant_hex_data.append(constant_tensor)
+            total_constant_len += len(constant_tensor) // 2
         else:
             if info.btype == BufferType.input_or_output:
                 buffer_type = classify_io(_buffer)
@@ -310,9 +306,8 @@ def assign_addresses(buffer_info, npu_ops):
                 address = arch_config.lut_start_address
                 buffer_addresses[_buffer] = (address, info.btype)
             else:
-                size_in_bytes = int(
-                    (np.iinfo(np.dtype(info.dtype)).bits // 8) * np.prod(list(info.shape))
-                )
+                dtype_bytes = np.iinfo(np.dtype(info.dtype)).bits // 8
+                size_in_bytes = int(dtype_bytes * np.prod(list(info.shape)))
                 # Every memory address the NPU access have to be 16 byte aligned
                 size_in_bytes = util.round_up(size_in_bytes, 16)
                 assert info.btype == BufferType.scratch
@@ -330,7 +325,12 @@ def assign_addresses(buffer_info, npu_ops):
             else:
                 setattr(npu_op, attr_name, replace_tir_loads(attr))
 
-    return npu_ops, constant_tensor, scratch_size
+    constant_data = "".join(constant_hex_data)
+    return (
+        npu_ops,
+        constant_data,
+        scratch_size,
+    )
 
 
 def translate_ethosu_tir_call_extern(tir_call_extern):
diff --git a/src/relay/op/contrib/ethosu/binary_elementwise.cc b/src/relay/op/contrib/ethosu/binary_elementwise.cc
index 4e0d086..e762245 100644
--- a/src/relay/op/contrib/ethosu/binary_elementwise.cc
+++ b/src/relay/op/contrib/ethosu/binary_elementwise.cc
@@ -128,21 +128,6 @@ struct EthosuBinaryElementwiseAttrs : public tvm::AttrsNode<EthosuBinaryElementw
 
 TVM_REGISTER_NODE_TYPE(EthosuBinaryElementwiseAttrs);
 
-bool IsScalarTensor(const Array<PrimExpr>& ifm_shape, const DataType& ifm_dtype) {
-  if (ifm_dtype != DataType::UInt(8)) {
-    return false;
-  }
-
-  for (const auto& expr : ifm_shape) {
-    const auto& dim_int_node = expr.as<IntImmNode>();
-    CHECK(dim_int_node) << "Expected IntImmNode for shape dimensions.";
-    int dim = dim_int_node->value;
-    if (dim != 1) return false;
-  }
-
-  return true;
-}
-
 bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                                 const TypeReporter& reporter) {
   const int ifm_index = 0;
@@ -167,11 +152,13 @@ bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const
     ofm_dtype = DataType::Int(8);
   } else if (param->ofm_dtype == "uint8") {
     ofm_dtype = DataType::UInt(8);
+  } else if (param->ofm_dtype == "int16") {
+    ofm_dtype = DataType::Int(16);
   } else if (param->ofm_dtype == "int32") {
     ofm_dtype = DataType::Int(32);
   }
 
-  if (ifm_dtype != ifm2_dtype && !IsScalarTensor(ifm2->shape, ifm2_dtype)) {
+  if (ifm_dtype != ifm2_dtype) {
     reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
                                      << "Invalid operator: expected ethosu_binary_elementwise "
                                      << "type for ifm2 be the same of ifm but was " << ifm2_dtype
@@ -189,11 +176,11 @@ bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, const
       return false;
     }
     if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
-        ofm_dtype != DataType::Int(32)) {
+        ofm_dtype != DataType::Int(16) && ofm_dtype != DataType::Int(32)) {
       reporter->GetDiagCtx().EmitFatal(
           Diagnostic::Error(reporter->GetSpan())
           << "Invalid operator: expected ethosu_binary_elementwise " << operator_type
-          << " type(uint8) or type(int8) or type(int32) for ofm but was " << ofm_dtype);
+          << " type(uint8), type(int8), type(int16) or type(int32) for ofm but was " << ofm_dtype);
       return false;
     }
   } else if (operator_type == "MIN" || operator_type == "MAX") {
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 6925f92..0e55487 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -520,8 +520,8 @@ def test_mean(accel_type, ifm_shape, axis, keep_dims, use_same_quantization):
 
 
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
-def test_binary_add_from_constant_scalar(accel_type):
-    dtype = "uint8"
+@pytest.mark.parametrize("dtype", ["int8", "uint8"])
+def test_elementwise_add_from_constant_scalar(accel_type, dtype):
     ifm_shape = (1, 4, 4, 8)
 
     def create_relay_graph():
@@ -543,7 +543,11 @@ def test_binary_add_from_constant_scalar(accel_type):
     ethosu_mod = partition_for_ethosu(cpu_mod)
 
     # Generate reference data
-    input_data = {"input": np.random.randint(low=0, high=255, size=ifm_shape, dtype=dtype)}
+    input_data = {
+        "input": np.random.randint(
+            low=np.iinfo(dtype).min, high=np.iinfo(dtype).max, size=ifm_shape, dtype=dtype
+        ),
+    }
     output_data = generate_ref_data(cpu_mod, input_data)
 
     _compare_ethosu_with_reference(
diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
index 59b7b2c..c14deb6 100644
--- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
+++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
@@ -651,6 +651,31 @@ def test_translate_ethosu_copy():
             assert npu_dma_op.dest.length == test_case["ref"][idx]["length"]
 
 
+# fmt: off
+@tvm.script.ir_module
+class MixedConstantDatatypes:
+    @T.prim_func
+    def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle, placeholder_3: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+        placeholder_4 = T.match_buffer(placeholder, [1, 8, 16, 16], dtype="int8")
+        buffer = T.match_buffer(placeholder_1, [160], dtype="uint8")
+        placeholder_5 = T.match_buffer(placeholder_2, [1, 1, 1, 1], dtype="int16")
+        ethosu_write_1 = T.match_buffer(ethosu_write, [1, 1, 1, 16], dtype="int8")
+        buffer_1 = T.match_buffer(placeholder_3, [272], dtype="uint8")
+        # body
+        placeholder_global = T.allocate([272], "uint8", "global")
+        placeholder_d_global = T.allocate([160], "uint8", "global")
+        ethosu_write_2 = T.allocate([16], "int16", "global")
+        placeholder_d_global_1 = T.allocate([1], "int16", "global")
+        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 272, T.load("uint8", placeholder_global, 0), dtype="uint8"))
+        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 160, T.load("uint8", placeholder_d_global, 0), dtype="uint8"))
+        T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 16, 16, 8, 0, 16, T.load("int8", placeholder_4.data, 0), 0, 0, 0, T.float32(0.0039215548895299435), -128, "NHWC", 256, 16, 1, "int16", 1, 1, 16, 1, 0, 1, T.load("int16", ethosu_write_2, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, 16, 8, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 0, T.load("uint8", placeholder_d_global, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="int16"))
+        T.evaluate(T.call_extern("ethosu_copy", T.load("int16", placeholder_5.data, 0), 1, T.load("int16", placeholder_d_global_1, 0), dtype="int16"))
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int16", 1, 1, 16, 1, 0, 1, T.load("int16", ethosu_write_2, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "int16", 1, 1, 1, 1, 0, 1, T.load("int16", placeholder_d_global_1, 0), 0, 0, 0, T.float32(0.0078125018482064768), 0, "NHWC", 1, 1, 1, "int8", 1, 1, 16, 1, 0, 1, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.0023205536417663097), -128, "NHWC", 1, 1, 1, "MUL", 0, "NONE", 0, 0, "NATURAL" [...]
+# fmt: on
+
+
 def test_assign_addresses():
     test_cases = [
         {
@@ -683,6 +708,15 @@ def test_assign_addresses():
                 11: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"),
             },
         },
+        {
+            # Stimulus
+            "tir_module": MixedConstantDatatypes,
+            "param_dict": {
+                1: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [160], "uint8"),
+                2: np.random.randint(np.iinfo("int16").min, np.iinfo("int16").max, [1], "int16"),
+                4: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [272], "uint8"),
+            },
+        },
     ]
 
     def extract_call_extern_list(mod):
@@ -747,24 +781,36 @@ def test_assign_addresses():
             4: tir_to_cs_translator.BufferType.output,
         }
         buffer_type = inverse_region_map[region]
+        buffer_dtype = buffer_var.type_annotation.element_type.dtype
+        dtype_bytes = np.iinfo(np.dtype(buffer_dtype)).bits // 8
         if buffer_type == tir_to_cs_translator.BufferType.constant:
             ref = buffer_info[buffer_var].values
-            assert (constant_tensor[address : address + length] == ref).all()
+            hex_from = address * dtype_bytes * 2
+            hex_to = hex_from + length * dtype_bytes * 2
+            constant_hex = constant_hex_string[hex_from:hex_to]
+            constant_tensor = np.frombuffer(bytearray.fromhex(constant_hex), dtype=buffer_dtype)
+            np.array_equal(constant_tensor, ref)
             # Every buffer is adjusted to align to 16 bytes
             length = util.round_up(length, 16)
             # Mark these constants are read at least once
-            constant_tensor_read_mask[address : address + length] = np.ones(length, dtype="uint8")
+            constant_tensor_read_mask[address : address + length] = np.ones(
+                length, dtype=buffer_dtype
+            )
         elif buffer_type == tir_to_cs_translator.BufferType.scratch:
             shape = list(buffer_info[buffer_var].shape)
             assert length == np.prod(shape)
             assert address < scratch_size
+
+            size_in_bytes = int(np.prod(shape)) * dtype_bytes
             # Every buffer is adjusted to align to 16 bytes
-            length = util.round_up(length, 16)
-            assert address + length <= scratch_size
+            size_in_bytes = util.round_up(size_in_bytes, 16)
+            assert address + size_in_bytes <= scratch_size
             # The scratch area should not be used by anyother buffer
-            assert not scratch_allocation_mask[address : address + length].any()
+            assert not scratch_allocation_mask[address : address + size_in_bytes].any()
             # The scratch area is marked as used
-            scratch_allocation_mask[address : address + length] = np.ones(length, dtype="uint8")
+            scratch_allocation_mask[address : address + size_in_bytes] = np.ones(
+                size_in_bytes, dtype="uint8"
+            )
         elif buffer_type == tir_to_cs_translator.BufferType.input:
             assert address == 0
         else:
@@ -841,11 +887,11 @@ def test_assign_addresses():
         for extern_call in extern_calls:
             _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_call_extern(extern_call))
         npu_op_tir_buffers = collect_tir_buffer_info(_npu_ops)
-        _npu_ops, constant_tensor, scratch_size = tir_to_cs_translator.assign_addresses(
+        _npu_ops, constant_hex_string, scratch_size = tir_to_cs_translator.assign_addresses(
             buffer_info, _npu_ops
         )
         scratch_allocation_mask = np.zeros(scratch_size, dtype="uint8")
-        constant_tensor_read_mask = np.zeros(constant_tensor.size, dtype="uint8")
+        constant_tensor_read_mask = np.zeros(len(constant_hex_string) // 2, dtype="uint8")
         verify(_npu_ops)
         # This will be only 1 if all allocated scratch is used.
         assert np.prod(scratch_allocation_mask) == 1