Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/12/07 13:44:48 UTC

[GitHub] [tvm] lhutton1 commented on a change in pull request #9418: [TIR][USMP] adding the pass to convert to pool offsets

lhutton1 commented on a change in pull request #9418:
URL: https://github.com/apache/tvm/pull/9418#discussion_r763975438



##########
File path: python/tvm/tir/usmp/transform/transform.py
##########
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""USMP Transform Python API for passes"""
+# pylint: disable=invalid-name
+
+from typing import Dict
+
+from . import _ffi_api
+from ....tir import Stmt

Review comment:
       Might be better to use absolute imports here?
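As a concrete illustration of this suggestion, the relative imports in the hunk can be resolved to absolute module names with the standard library's `importlib.util.resolve_name`. The sketch assumes, based on the file path, that this module lives in the package `tvm.tir.usmp.transform`. The helper `absolute_module` is a hypothetical name used only here.

```python
from importlib.util import resolve_name

# Package of python/tvm/tir/usmp/transform/transform.py (assumed from the file path).
PACKAGE = "tvm.tir.usmp.transform"

# Relative module names used by the imports quoted in the hunk.
RELATIVE_IMPORTS = {
    "from . import _ffi_api": ".",
    "from ....tir import Stmt": "....tir",
    "from ..utils import PoolAllocation": "..utils",
}


def absolute_module(relative_name: str, package: str = PACKAGE) -> str:
    """Resolve a leading-dot module name against the given package (no import happens)."""
    return resolve_name(relative_name, package)


if __name__ == "__main__":
    for stmt, rel in RELATIVE_IMPORTS.items():
        print(f"{stmt}  ->  {absolute_module(rel)}")
```

Under that assumption, the absolute forms would be `from tvm.tir.usmp.transform import _ffi_api`, `from tvm.tir import Stmt` and `from tvm.tir.usmp.utils import PoolAllocation`.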

##########
File path: python/tvm/tir/usmp/transform/transform.py
##########
@@ -0,0 +1,45 @@
+"""USMP Transform Python API for passes"""
+# pylint: disable=invalid-name
+
+from typing import Dict
+
+from . import _ffi_api
+from ....tir import Stmt
+from ..utils import PoolAllocation
+
+
+def convert_pool_allocations_to_offsets(
+    pool_allocations: Dict[Stmt, PoolAllocation], emit_tvmscript_printable: bool = False
+):

Review comment:
       Nit: Missing return type hint
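A minimal sketch of the annotated signature. The return type is not visible in the hunk; this assumes the function returns a pass object such as `tvm.transform.Pass`. To keep the snippet runnable without TVM, `Pass`, `Stmt` and `PoolAllocation` are hypothetical stand-in classes, and `typing.get_type_hints` is used to show that the return annotation is recorded.

```python
from typing import Dict, get_type_hints


class Pass:
    """Hypothetical stand-in for tvm.transform.Pass (assumed return type)."""


class Stmt:
    """Hypothetical stand-in for tvm.tir.Stmt."""


class PoolAllocation:
    """Hypothetical stand-in for tvm.tir.usmp.utils.PoolAllocation."""


def convert_pool_allocations_to_offsets(
    pool_allocations: Dict[Stmt, PoolAllocation], emit_tvmscript_printable: bool = False
) -> Pass:
    """Return a pass that rewrites planned allocations as pool offsets (stub body)."""
    # The real function presumably forwards to the C++ pass through _ffi_api;
    # this stub only shows where the return annotation goes.
    return Pass()


if __name__ == "__main__":
    print(get_type_hints(convert_pool_allocations_to_offsets)["return"])
```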

##########
File path: tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
##########
@@ -0,0 +1,517 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import sys
+
+import tvm
+from tvm.script import tir as T
+from tvm.tir import stmt_functor
+from tvm.tir.usmp import utils as usmp_utils
+from tvm.target import Target
+
+
+def _get_primfuncs_from_module(module):
+    primfuncs = list()
+    for gv, primfunc in module.functions.items():
+        primfuncs.append(primfunc)
+    return primfuncs
+
+
+def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
+    """helper to assing poolinfos to allocate nodes in a tir.PrimFunc"""

Review comment:
       ```suggestion
       """Helper to assign poolinfos to allocate nodes in a tir.PrimFunc"""
       ```

       And similarly below.

##########
File path: src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
##########
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/usmp/transform/convert_pool_allocations_to_offsets.cc
+ * \brief This pass would convert the pool allocations to offsets from pools
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <stack>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+/*!
+ * \brief The StmtExpr mutator class to replace allocate nodes
+ * with offsets within memory pools
+ *
+ * This mutator class with add Pool variables recursively to every PrimFunc
+ * starting from the main PrimFunc. For all allocate nodes, that have been
+ * memory planned, will be mutated into an offset using a Let binding.
+ */
+class PoolAllocationToOffsetConverter : public StmtExprMutator {
+ public:
+  PoolAllocationToOffsetConverter(const IRModule& module,
+                                  const Map<tir::Stmt, PoolAllocation>& pool_allocations,
+                                  bool emit_tvmscript_printable = false)
+      : pool_allocations_(pool_allocations), emit_tvmscript_printable_(emit_tvmscript_printable) {
+    module_ = module->ShallowCopy();
+    for (const auto& kv : pool_allocations) {
+      // TODO(@manupa-arm): add AllocateConstNode when it is available
+      ICHECK(kv.first->IsInstance<AllocateNode>());
+      Allocate allocate_node = Downcast<Allocate>(kv.first);
+      PoolAllocation pool_allocation = kv.second;
+      PoolInfo pool_info = pool_allocation->pool_info;
+      int byte_pool_offset = pool_allocation->byte_offset->value;
+      int required_pool_size_for_allocation =
+          byte_pool_offset + CalculateExtentsSize(allocate_node.operator->());
+      if (all_pools_sizes_.find(pool_info) == all_pools_sizes_.end()) {
+        all_pools_sizes_[pool_info] = required_pool_size_for_allocation;
+      } else {
+        int prev_required_pool_size = all_pools_sizes_[pool_info];
+        if (prev_required_pool_size < required_pool_size_for_allocation) {
+          all_pools_sizes_[pool_info] = required_pool_size_for_allocation;
+        }
+      }
+    }
+
+    for (const auto& kv : all_pools_sizes_) {
+      PoolInfo pi = kv.first;
+      int allocated_size = kv.second;
+      allocated_pool_ordering_.push_back(AllocatedPoolInfo(pi, allocated_size));
+    }
+    std::sort(allocated_pool_ordering_.begin(), allocated_pool_ordering_.end(),
+              [](const AllocatedPoolInfo& lhs, const AllocatedPoolInfo& rhs) {
+                if (lhs->pool_info->pool_name < rhs->pool_info->pool_name) {
+                  return true;
+                }
+                return false;
+              });
+  }
+  IRModule operator()();
+
+ private:
+  PrimExpr VisitExpr_(const CallNode* op) override;
+  Stmt VisitStmt_(const AllocateNode* op) override;
+  //  PrimExpr VisitExpr_(const VarNode* op) override;
+  PrimExpr VisitExpr_(const LoadNode* op) override;
+  Stmt VisitStmt_(const StoreNode* op) override;
+
+  /*! \brief This is a structure where the modified function
+   * signature is kept while body of the function is mutated
+   */
+  struct ScopeInfo {
+    Array<tir::Var> params;
+    Map<PoolInfo, tir::Var> pools_to_params;
+    Array<AllocatedPoolInfo> allocated_pool_params;
+    Map<tir::Var, Buffer> buffer_map;
+  };
+
+  /*! \brief The function scope information that are needed
+   * in the mutation of the function need to be stacked and
+   * popped when each function is entered/exited in the
+   * mutation process.
+   */
+  std::stack<ScopeInfo> scope_stack;
+  /*! \brief Each PrimFunc signature needs to be updated
+   * with pool variables. This is a helper function to
+   * capture the updated information to ScopeInfo object.
+   */
+  ScopeInfo UpdateFunctionScopeInfo(const PrimFunc& original_func);
+  /*! \brief This is a helper to create the PrimFunc with
+   * pool variables that calls the UpdateFunctionScopeInfo
+   * inside of it.
+   */
+  PrimFunc CreatePrimFuncWithPoolParams(const PrimFunc& original_primfunc);
+  /*! \brief This is a helper to append the pool args to
+   * the callsite of the function.
+   */
+  Array<PrimExpr> AppendPoolParamsToArgs(const Array<PrimExpr>& args);
+  /*! \brief Some arguments that used to be Allocate nodes
+   * should be replaced by Let nodes in the pass that loads
+   * the space from a pool variable.
+   */
+  Array<PrimExpr> ReplaceAllocateArgsWithLetArgs(const Array<PrimExpr>& args);
+
+  /*! \brief The tir::Var map to PoolInfo objects */
+  Map<tir::Var, PoolInfo> primfunc_args_to_pool_info_map_;
+  /*! \brief The buffer var map to their allocate nodes */
+  Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map_;
+  /*! \brief The IRModule being constructed/mutated */
+  IRModule module_;
+  /*! \brief The input allocate node to PoolAllocation map */
+  Map<tir::Stmt, PoolAllocation> pool_allocations_;
+  /*! \brief The set of ordered pools to ensure an unique order of args for functions */
+  std::vector<AllocatedPoolInfo> allocated_pool_ordering_;
+  /*! \brief The storage of calculated pool size at init */
+  std::unordered_map<PoolInfo, int, ObjectPtrHash, ObjectPtrEqual> all_pools_sizes_;
+  /*! \brief After mutation, each allocate buffer is replaced with tir::Var that is let bounded
+   * to position from a pool as designated by a PoolAllocation
+   */
+  Map<tir::Var, tir::Var> allocate_buf_to_let_var_;
+  /*! \brief A counter to give references to pools a reproducible unique set of names */
+  int pool_var_count_ = 0;
+  /*! \brief This toggles to remove non tvmscript printable items for IRModule for unit tests */
+  bool emit_tvmscript_printable_ = false;
+  /*! \brief A counter to give references to pools a reproducible unique set of names */
+  std::unordered_set<PrimFunc, ObjectPtrHash, ObjectPtrEqual> visited_primfuncs;
+};
+
+PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::UpdateFunctionScopeInfo(
+    const PrimFunc& original_func) {
+  ScopeInfo si;
+  si.params = original_func->params;
+  si.buffer_map = original_func->buffer_map;
+  Map<tir::Var, PoolInfo> ret;
+  for (const AllocatedPoolInfo& allocated_pool_info : allocated_pool_ordering_) {
+    PoolInfo pool_info = allocated_pool_info->pool_info;
+    String pool_ref_name = pool_info->pool_name + "_" + std::to_string(pool_var_count_++);
+    String var_name = pool_ref_name + "_var";
+    DataType elem_dtype = DataType::UInt(8);
+    Var buffer_var(var_name, PointerType(PrimType(elem_dtype), "global"));
+    Var pool_var;
+    if (!emit_tvmscript_printable_) {
+      pool_var = Var(var_name, PointerType(PrimType(elem_dtype), "global"));
+    } else {
+      pool_var = Var(var_name, DataType::Handle(8));
+    }
+    si.params.push_back(pool_var);
+    si.pools_to_params.Set(pool_info, pool_var);
+    si.allocated_pool_params.push_back(AllocatedPoolInfo(
+        allocated_pool_info->pool_info, allocated_pool_info->allocated_size, pool_var));
+
+    int pool_size = all_pools_sizes_[pool_info];
+    String buffer_var_name = pool_ref_name + "_buffer_var";
+    si.buffer_map.Set(pool_var, Buffer(buffer_var, elem_dtype, {pool_size}, {1}, 1, buffer_var_name,
+                                       16, 1, BufferType::kDefault));
+  }
+  return si;
+}
+
+PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams(
+    const PrimFunc& original_primfunc) {
+  // Only create the new function if it was not modified with pool params
+  if (visited_primfuncs.find(original_primfunc) == visited_primfuncs.end()) {
+    ScopeInfo si = UpdateFunctionScopeInfo(original_primfunc);
+    this->scope_stack.push(si);
+    Stmt new_body = this->VisitStmt(original_primfunc->body);
+    this->scope_stack.pop();
+    DictAttrs original_attrs = original_primfunc->attrs;
+    // We dont need attrs of PrimFunc that might include non printable attrs such as target
+    // for unit tests where emit_tvmscript_printable_ is to be used.
+    if (emit_tvmscript_printable_) {
+      original_attrs = DictAttrs();
+    }
+    PrimFunc ret =
+        PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs);
+    if (!emit_tvmscript_printable_) {
+      return WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params);
+    }
+    visited_primfuncs.insert(ret);
+    return ret;
+  }
+  return original_primfunc;
+}
+
+Array<PrimExpr> PoolAllocationToOffsetConverter::AppendPoolParamsToArgs(
+    const Array<PrimExpr>& args) {
+  Array<PrimExpr> new_args;
+  for (const auto& arg : args) {
+    new_args.push_back(VisitExpr(arg));
+  }
+  ScopeInfo top_scope = this->scope_stack.top();
+  for (const auto& pools_vars : top_scope.pools_to_params) {
+    tir::Var pool_var = pools_vars.second;
+    Buffer buffer_var = top_scope.buffer_map[pool_var];
+    new_args.push_back(buffer_var->data);
+  }
+  return new_args;
+}
+
+Array<PrimExpr> PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs(
+    const Array<PrimExpr>& args) {
+  Array<PrimExpr> ret;
+  for (const PrimExpr& arg : args) {
+    if (arg->IsInstance<VarNode>() &&
+        allocate_buf_to_let_var_.find(Downcast<Var>(arg)) != allocate_buf_to_let_var_.end()) {
+      ret.push_back(allocate_buf_to_let_var_[Downcast<Var>(arg)]);
+    } else {
+      ret.push_back(arg);
+    }
+  }
+  return ret;
+}
+
+PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) {
+  if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
+    String func_name = Downcast<StringImm>(op->args[0])->value;
+    Array<PrimExpr> new_args;
+    if (module_->ContainGlobalVar(func_name)) {
+      GlobalVar gv = module_->GetGlobalVar(func_name);
+      PrimFunc func = Downcast<PrimFunc>(module_->Lookup(gv));
+      PrimFunc prim_func = CreatePrimFuncWithPoolParams(func);
+      module_->Update(gv, prim_func);
+      new_args = AppendPoolParamsToArgs(op->args);
+      new_args = ReplaceAllocateArgsWithLetArgs(new_args);
+    } else {
+      new_args = ReplaceAllocateArgsWithLetArgs(op->args);
+    }
+    return Call(op->dtype, op->op, new_args);
+  }
+  if (op->op->IsInstance<PrimFuncNode>()) {
+    PrimFunc func = Downcast<PrimFunc>(op->op);
+    PrimFunc prim_func = CreatePrimFuncWithPoolParams(func);
+    Array<PrimExpr> new_args = AppendPoolParamsToArgs(op->args);
+    new_args = AppendPoolParamsToArgs(new_args);
+    new_args = ReplaceAllocateArgsWithLetArgs(new_args);
+    return Call(op->dtype, prim_func, new_args);
+  }
+  return StmtExprMutator::VisitExpr_(op);
+}
+
+Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) {
+  if (pool_allocations_.count(GetRef<Allocate>(op))) {
+    ScopeInfo scope_info = scope_stack.top();
+    PoolAllocation pool_allocation = pool_allocations_[GetRef<Allocate>(op)];
+    Var param = scope_info.pools_to_params[pool_allocation->pool_info];
+    Buffer buffer_var = scope_info.buffer_map[param];
+    ICHECK(pool_allocation->byte_offset < all_pools_sizes_[pool_allocation->pool_info]);

Review comment:
       Perhaps include a message explaining why this condition failed.
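The condition guards the pool sizes computed in the constructor, where each pool's size is the largest `byte_offset + allocation size` among the allocations assigned to it. The Python sketch below mirrors that computation using hypothetical plain-data stand-ins (`Allocation`, `required_pool_sizes` and `check_offset` are illustrative names, not TVM APIs) and shows the kind of failure message being asked for. In the C++ code itself, `ICHECK` accepts a streamed message, e.g. `ICHECK(cond) << "..."`.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class Allocation:
    """Hypothetical stand-in for a planned Allocate node and its PoolAllocation."""

    name: str
    pool_name: str
    byte_offset: int
    size_bytes: int  # plays the role of CalculateExtentsSize(allocate)


def required_pool_sizes(allocations: List[Allocation]) -> Dict[str, int]:
    """Size each pool to cover its furthest-reaching allocation (offset + size)."""
    sizes: Dict[str, int] = {}
    for a in allocations:
        end = a.byte_offset + a.size_bytes
        sizes[a.pool_name] = max(sizes.get(a.pool_name, 0), end)
    return sizes


def check_offset(a: Allocation, pool_sizes: Dict[str, int]) -> None:
    """Same condition as the ICHECK, but failing with an explanatory message."""
    pool_size = pool_sizes[a.pool_name]
    if not a.byte_offset < pool_size:
        raise ValueError(
            f"allocation {a.name!r} starts at byte offset {a.byte_offset}, "
            f"outside pool {a.pool_name!r} of size {pool_size} bytes"
        )


if __name__ == "__main__":
    allocs = [
        Allocation("a", "sram", 0, 64),
        Allocation("b", "sram", 64, 32),
        Allocation("c", "dram", 0, 128),
    ]
    sizes = required_pool_sizes(allocs)
    print(sizes)  # {'sram': 96, 'dram': 128}
    print(sorted(sizes))  # pools sorted by name, as for the added function params
    for alloc in allocs:
        check_offset(alloc, sizes)  # all within bounds, no error
    try:
        # An allocation checked against an inconsistent (stale) pool size.
        check_offset(Allocation("stale", "sram", 96, 4), sizes)
    except ValueError as err:
        print(err)
```

The sort by pool name reflects the constructor's `std::sort`, which gives every rewritten PrimFunc the same pool-argument order.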

##########
File path: src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
##########
@@ -0,0 +1,351 @@
+/*!
+ * \file tir/analysis/usmp/transform/convert_pool_allocations_to_offsets.cc
+ * \brief This pass would convert the pool allocations to offsets from pools
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <stack>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+/*!
+ * \brief The StmtExpr mutator class to replace allocate nodes
+ * with offsets within memory pools
+ *
+ * This mutator class with add Pool variables recursively to every PrimFunc
+ * starting from the main PrimFunc. For all allocate nodes, that have been
+ * memory planned, will be mutated into an offset using a Let binding.
+ */
+class PoolAllocationToOffsetConverter : public StmtExprMutator {
+ public:
+  PoolAllocationToOffsetConverter(const IRModule& module,
+                                  const Map<tir::Stmt, PoolAllocation>& pool_allocations,
+                                  bool emit_tvmscript_printable = false)
+      : pool_allocations_(pool_allocations), emit_tvmscript_printable_(emit_tvmscript_printable) {
+    module_ = module->ShallowCopy();
+    for (const auto& kv : pool_allocations) {
+      // TODO(@manupa-arm): add AllocateConstNode when it is available
+      ICHECK(kv.first->IsInstance<AllocateNode>());
+      Allocate allocate_node = Downcast<Allocate>(kv.first);
+      PoolAllocation pool_allocation = kv.second;
+      PoolInfo pool_info = pool_allocation->pool_info;
+      int byte_pool_offset = pool_allocation->byte_offset->value;
+      int required_pool_size_for_allocation =
+          byte_pool_offset + CalculateExtentsSize(allocate_node.operator->());
+      if (all_pools_sizes_.find(pool_info) == all_pools_sizes_.end()) {
+        all_pools_sizes_[pool_info] = required_pool_size_for_allocation;
+      } else {
+        int prev_required_pool_size = all_pools_sizes_[pool_info];
+        if (prev_required_pool_size < required_pool_size_for_allocation) {
+          all_pools_sizes_[pool_info] = required_pool_size_for_allocation;
+        }
+      }
+    }
+
+    for (const auto& kv : all_pools_sizes_) {
+      PoolInfo pi = kv.first;
+      int allocated_size = kv.second;
+      allocated_pool_ordering_.push_back(AllocatedPoolInfo(pi, allocated_size));
+    }
+    std::sort(allocated_pool_ordering_.begin(), allocated_pool_ordering_.end(),
+              [](const AllocatedPoolInfo& lhs, const AllocatedPoolInfo& rhs) {
+                if (lhs->pool_info->pool_name < rhs->pool_info->pool_name) {
+                  return true;
+                }
+                return false;
+              });
+  }
+  IRModule operator()();
+
+ private:
+  PrimExpr VisitExpr_(const CallNode* op) override;
+  Stmt VisitStmt_(const AllocateNode* op) override;
+  //  PrimExpr VisitExpr_(const VarNode* op) override;

Review comment:
       Needs removing?

##########
File path: src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
##########
@@ -0,0 +1,351 @@
+/*!
+ * \file tir/analysis/usmp/transform/convert_pool_allocations_to_offsets.cc
+ * \brief This pass would convert the pool allocations to offsets from pools
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <stack>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+/*!
+ * \brief The StmtExpr mutator class to replace allocate nodes
+ * with offsets within memory pools
+ *
+ * This mutator class with add Pool variables recursively to every PrimFunc

Review comment:
       ```suggestion
        * This mutator class will add Pool variables recursively to every PrimFunc
       ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org