You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text version).
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/02/04 12:01:02 UTC

[GitHub] [tvm] manupa-arm commented on a change in pull request #10022: [microNPU] enable USMP

manupa-arm commented on a change in pull request #10022:
URL: https://github.com/apache/tvm/pull/10022#discussion_r799405283



##########
File path: python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
##########
@@ -81,6 +72,107 @@ def get_accelerator_arch_config(accel_type):
     return accel_config_str_map[accel_type]
 
 
+class RegionOffset(NamedTuple):
+    """A data structure to hold region and address offset corresponding to a tensor"""
+
+    region: int
+    offset: int
+
+
+def analyze_scratch_memory_acesses(mod: tvm.IRModule, candidate_regions_for_scratch: List[int]):
+    """
+    Parameters
+    ----------
+    mod: tvm.IRModule
+        The TIR module containing ethosu extern calls
+    candidate_regions_for_scratch: List[int]
+        A list of region integers that could be used for scratch regions
+
+    Returns
+    -------
+    scratch_region_map : Dict[tvm.tir.Var, int]
+        A map between buffer vars to scratch regions they are assigned
+    tvm_backend_alloc_workspace_size : int
+        The size of tvm_backend_alloc_workspace call required to service
+        remaining allocate nodes if any
+    tvm_backend_alloc_workspace_region : int
+        The region associated with the tvm_backend_alloc_workspace
+    """
+    scratch_region_map = dict()
+    pool_var_region_map = dict()
+    # There should only be a single function
+    assert len(mod.functions.items()) == 1
+    primfunc = mod.functions.items()[0][1]
+    if "pool_args" in primfunc.attrs.keys():
+        pool_args = primfunc.attrs["pool_args"]
+        for pool_arg in pool_args:
+            pool_param = primfunc.params[int(pool_arg.pool_var_idx)]
+            pool_var_region_map[pool_param] = candidate_regions_for_scratch.pop()
+            scratch_region_map[pool_param] = RegionOffset(
+                region=pool_var_region_map[pool_param], offset=None
+            )
+
+    def analyze_pool_access(stmt):
+        if isinstance(stmt, tvm.tir.stmt.LetStmt):
+            call_address_of = stmt.value
+            load = call_address_of.args[0]
+            pool_var = load.buffer_var
+            scratch_region_map[stmt.var] = RegionOffset(
+                region=pool_var_region_map[pool_var], offset=int(load.index)
+            )
+
+    tvm.tir.stmt_functor.post_order_visit(primfunc.body, analyze_pool_access)
+
+    tvmbaw_region = None
+    if len(candidate_regions_for_scratch) > 0:
+        tvmbaw_region = candidate_regions_for_scratch.pop()
+
+        # Need a mutable data structure to be updated by the following function
+        # Therefore, using a list instead of int
+        tvmbaw_size = [0]

Review comment:
       Cool! Sounds like a good idea.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org