You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/12/09 16:14:56 UTC

[GitHub] [tvm] junrushao1994 commented on a change in pull request #9689: [TIR] Allow memory (aka storage) scopes to be retrieved/applied to PrimFuncs

junrushao1994 commented on a change in pull request #9689:
URL: https://github.com/apache/tvm/pull/9689#discussion_r765920384



##########
File path: include/tvm/tir/analysis.h
##########
@@ -26,12 +26,14 @@
 
 #include <tvm/ir/module.h>
 #include <tvm/ir/transform.h>
+#include <tvm/target/se_scope.h>

Review comment:
       nit: probably we don't need two includes?

##########
File path: src/tir/analysis/device_constraint_utils.h
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/device_constraint_utils.h
+ * \brief Utilities for extracting and applying device-related constraints to \p PrimFunc
+ * parameters.
+ *
+ * These utilities are used by the \p PlanDevices pass to extract memory (aka 'storage') scope
+ * information from \p PrimFuncs and convert them back into \p SEScope form w.r.t. the original
+ * Relay type of the \p PrimFunc (ie before flattening of tuple arguments/results and conversion
+ * to destination-passing style aka DPS).
+ *
+ * A utility is also supplied to go the other way: impose memory scopes on \p PrimFunc parameters.
+ * However that's still in EXPERIMENTAL form.
+ *
+ * We may extend these utilities to also gather/apply layout information should we add that to
+ * \p SEScope.
+ */
+
+#ifndef TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_
+#define TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_
+
+#include <tvm/target/se_scope.h>
+#include <tvm/tir/function.h>
+
+namespace tvm {
+namespace tir {
+
+/*
+ * A Relay Function with type:
+ * \code
+ *   fn((Tensor[...], Tensor[...]), Tensor[...]) -> (Tensor[...], Tensor[...])
+ *       ^            ^             ^                ^            ^
+ *       a            b             c                d            e
+ * \endcode
+ * will be represented by a TIR PrimFunc in flattened and DPS form with at least 5 arguments a..e.
+ * Each such PrimFunc argument will have a type annotation for a PointerType to the underlying
+ * tensor's buffer. The PrimFunc may have additional non-pointer arguments, for example to represent

Review comment:
       > The PrimFunc may have additional non-pointer arguments
   
   Yeah that's correct. Another example is that PrimFunc takes scalars as inputs too, but there is no correspondence in Relay either

##########
File path: src/tir/analysis/device_constraint_utils.cc
##########
@@ -0,0 +1,523 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/device_constraint_utils.cc
+ * \brief Applies device-related constraints to \p PrimFunc parameters.
+ *
+ * This is used by the \p PlanDevices pass to flow device-constraints *into* \p PrimFuncs.
+ *
+ * Currently only applies memory scope constraints into \p Buffer data pointer
+ * storage scopes. Aliased ('matched') buffers take on any scope introduced on
+ * the buffer they alias. However, it does not currently attempt to flow constraints into
+ * allocated buffers.
+ */
+
+#include "./device_constraint_utils.h"
+
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/target/se_scope.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+namespace {
+
+/*!
+ * \brief Looks up the pointer type carried by the data variable of \p buffer.
+ *
+ * \param buffer The buffer whose \p data var type annotation is inspected.
+ * \return The \p PointerTypeNode annotating \p buffer->data, or nullptr when that annotation is
+ * undefined or is not a pointer type.
+ */
+const PointerTypeNode* PointerInBuffer(const tir::Buffer& buffer) {
+  const Type& annotation = buffer->data->type_annotation;
+  if (!annotation.defined()) {
+    return nullptr;
+  }
+  return annotation.as<PointerTypeNode>();
+}
+
+/*!
+ * \brief Scans \p prim_func parameters starting at \p *current_primfunc_param_index and returns
+ * the first one that is bound in the buffer map to a buffer whose data var has a pointer type.
+ *
+ * Parameters with no buffer map entry, or whose buffer is not pointer-typed, are skipped. On
+ * return \p *current_primfunc_param_index is the index of the returned parameter (it is NOT
+ * advanced past it). Check-fails if the parameters are exhausted before a match is found. This
+ * is used to pair each tensor of a flattened Relay parameter or result with its PrimFunc param.
+ */
+std::pair<tir::Var, tir::Buffer> FindPointerParam(const tir::PrimFunc& prim_func,
+                                                  size_t* current_primfunc_param_index) {
+  for (;; ++*current_primfunc_param_index) {
+    ICHECK_LT(*current_primfunc_param_index, prim_func->params.size());
+    const tir::Var& param = prim_func->params[*current_primfunc_param_index];
+    auto itr = prim_func->buffer_map.find(param);
+    const bool has_buffer = itr != prim_func->buffer_map.end();
+    if (!has_buffer) {
+      VLOG(2) << "no buffer map entry for '" << param->name_hint << "'";
+      continue;
+    }
+    if (PointerInBuffer((*itr).second) == nullptr) {
+      VLOG(2) << "not a pointer type for '" << param->name_hint << "'";
+      continue;
+    }
+    VLOG(2) << "using PrimFunc param '" << param->name_hint << "'";
+    return *itr;
+  }
+}
+
+/*!
+ * \brief Walks the \p prim_func parameters from \p *current_primfunc_param_index to the end and
+ * check-fails on any that is bound to a pointer-typed buffer.
+ *
+ * Use after the \p FindPointerParam calls to confirm every pointer parameter was consumed.
+ * On return \p *current_primfunc_param_index equals the number of parameters.
+ */
+void CheckNoRemainingPointerParams(const tir::PrimFunc& prim_func,
+                                   size_t* current_primfunc_param_index) {
+  const size_t num_params = prim_func->params.size();
+  for (; *current_primfunc_param_index < num_params; ++*current_primfunc_param_index) {
+    const tir::Var& param = prim_func->params[*current_primfunc_param_index];
+    auto itr = prim_func->buffer_map.find(param);
+    if (itr == prim_func->buffer_map.end()) {
+      VLOG(1) << "no buffer map entry for '" << param->name_hint << "'";
+    } else {
+      // A buffer-backed parameter left over at this point must not be a pointer.
+      const auto* pointer_type_node = PointerInBuffer((*itr).second);
+      ICHECK(pointer_type_node == nullptr);
+    }
+  }
+}
+
+/*!
+ * \brief Returns the (consistent) constraint to use for a Relay parameter of \p type,
+ * using \p prim_func parameters at or after \p *current_primfunc_param_index. Currently
+ * only memory scope is extracted. Fails if constraints are not consistent, ie \p type is a tuple
+ * type and the \p prim_func is attempting to map different fields of that tuple to different memory
+ * scopes. Returns the fully unconstrained \p SEScope if no memory scope constraints arise from
+ * the \p prim_func, ie all storage scope strings in pointer types are empty.
+ *
+ * On return \p *current_primfunc_param_index has been advanced past the last pointer parameter
+ * consumed for \p type.
+ */
+SEScope ConsistentParamConstraint(const tir::PrimFunc& prim_func, const Type& type,
+                                  size_t* current_primfunc_param_index) {
+  std::string memory_scope;  // default empty => no constraint
+  // Flattening builds a fresh array, so compute the count once instead of on every loop test.
+  const size_t num_flat_tensors = relay::FlattenTupleType(type).size();
+  for (size_t i = 0; i < num_flat_tensors; ++i) {
+    std::pair<tir::Var, tir::Buffer> kv = FindPointerParam(prim_func, current_primfunc_param_index);
+    const tir::Buffer& buffer = kv.second;
+    // FindPointerParam only returns buffers whose data var is pointer-typed.
+    const PointerTypeNode* pointer_type_node = PointerInBuffer(buffer);
+    ICHECK(pointer_type_node != nullptr);
+    const MemoryScope& buffer_memory_scope = pointer_type_node->storage_scope;
+    if (memory_scope.empty()) {
+      memory_scope = buffer_memory_scope;
+    } else if (buffer_memory_scope.empty()) {
+      // No constraint.
+    } else {
+      // Tuples must be homogeneous on their SEScope and thus memory scope.
+      ICHECK_EQ(buffer_memory_scope, memory_scope);
+    }
+    ++*current_primfunc_param_index;
+  }
+  return SEScope::ForMemoryScope(memory_scope);
+}
+
+/*!
+ * \brief Insert into \p param_constraints an entry for each parameter of \p prim_func starting from
+ * \p *current_primfunc_param_index for the flattened form of a Relay parameter of \p type. Each
+ * entry maps to \p se_scope. Existing entries for the same parameter are left untouched.
+ *
+ * On return \p *current_primfunc_param_index has been advanced past the last pointer parameter
+ * consumed for \p type.
+ */
+void InsertParamConstraints(const tir::PrimFunc& prim_func, const Type& type,
+                            const SEScope& se_scope, size_t* current_primfunc_param_index,
+                            std::unordered_map<const tir::VarNode*, SEScope>* param_constraints) {
+  // Flattening builds a fresh array, so compute the count once instead of on every loop test.
+  const size_t num_flat_tensors = relay::FlattenTupleType(type).size();
+  for (size_t i = 0; i < num_flat_tensors; ++i) {
+    std::pair<tir::Var, tir::Buffer> kv = FindPointerParam(prim_func, current_primfunc_param_index);
+    param_constraints->emplace(kv.first.get(), se_scope);
+    ++*current_primfunc_param_index;
+  }
+}
+
+/*!
+ * \brief Apply the memory scope constraints to the \p Buffers and data \p Vars of a \p PrimFunc.
+ *
+ * All definitional occurrences of buffer Vars are rewritten to capture memory scopes in their
+ * PointerTypes:
+ *  - Buffer::data (if the buffer itself is a definitional occurrence)
+ *  - AllocateNode::buffer_var
+ *  - FUTURE: LetStmtNode::var if aliasing a buffer data var.
+ *
+ * All referential occurrences of buffer Vars are replaced with their new definitions:
+ *  - LoadNode::buffer_var
+ *  - StoreNode::buffer_var
+ *
+ * Similarly all definitional occurrences of Buffers are rewritten to account for any new memory
+ * scopes:
+ *  - PrimFuncNode::buffer_map keys.
+ *  - BlockNode::match_buffers.buffer
+ *  - FUTURE: BlockNode::alloc_buffers?
+ *
+ * And all referential occurrences of Buffers are replaced with their new definitions:
+ *  - BufferLoadNode::buffer
+ *  - BufferStoreNode::buffer
+ *  - BufferRealizeNode::buffer
+ *  - PrefetchNode::buffer
+ *  - BufferRegionNode::buffer
+ *  - BlockNode.match_buffers.source.buffer
+ *  - BlockNode::{reads, writes}.buffer
+ *
+ * CAUTION: We assume strict sharing of Buffer objects and do not attempt to rewrite the bodies
+ * of referential buffers.
+ *
+ * CAUTION: EXPERIMENTAL: We don't yet account for all buffers and pointer types.
+ */
+class ApplyDeviceConstraintsMutator : public StmtExprMutator {
+ public:
+  /*! \brief Default-constructs the mutator; all members take their default initial values. */
+  ApplyDeviceConstraintsMutator() = default;
+
+  /*!
+   * \brief Returns \p prim_func rewritten to capture the memory scope constraints in \p
+   * arg_and_result_se_scopes for each pointer \p prim_func parameter. Returns \p prim_func
+   * unchanged if no memory scopes needed to change.
+   *
+   * \param prim_func The function to rewrite.
+   * \param relay_func_type The Relay type of \p prim_func, ie before tuple flattening and
+   * conversion to destination-passing style.
+   * \param arg_and_result_se_scopes One \p SEScope per Relay parameter, followed by one for the
+   * Relay result. Check-fails if the count does not match \p relay_func_type.
+   */
+  PrimFunc Rewrite(const PrimFunc& prim_func, const FuncType& relay_func_type,
+                   const Array<SEScope>& arg_and_result_se_scopes) {
+    // Guard against silently ignoring extra scopes (or misreading the result scope).
+    ICHECK_EQ(arg_and_result_se_scopes.size(), relay_func_type->arg_types.size() + 1)
+        << "expected one SEScope per Relay parameter plus one for the Relay result";
+
+    size_t current_primfunc_param_index = 0;
+    std::unordered_map<const tir::VarNode*, SEScope> param_constraints;
+
+    // For each Relay function parameter...
+    for (size_t i = 0; i < relay_func_type->arg_types.size(); ++i) {
+      const Type& param_type = relay_func_type->arg_types[i];
+      const SEScope& param_se_scope = arg_and_result_se_scopes[i];
+      InsertParamConstraints(prim_func, param_type, param_se_scope, &current_primfunc_param_index,
+                             &param_constraints);
+    }
+
+    // For the Relay function result...
+    const Type& ret_type = relay_func_type->ret_type;
+    const SEScope& ret_se_scope = arg_and_result_se_scopes.back();
+    InsertParamConstraints(prim_func, ret_type, ret_se_scope, &current_primfunc_param_index,
+                           &param_constraints);
+
+    // Make sure we accounted for all prim_func parameters. This leaves the index at the end, so
+    // a second check later would be a no-op.
+    CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index);
+
+    // Start with a copy of the current prim_func buffer map.
+    Map<Var, Buffer> new_buffer_map(prim_func->buffer_map.begin(), prim_func->buffer_map.end());
+    bool any_change = false;
+
+    // For each constrained parameter...
+    for (const auto& kv : param_constraints) {
+      const tir::Var param = GetRef<tir::Var>(kv.first);
+      const SEScope& se_scope = kv.second;
+      const tir::Buffer& buffer = prim_func->buffer_map[param];
+      // Rewrite the buffer to account for constraint.
+      const Buffer new_buffer = RewriteBuffer(buffer, se_scope);
+      if (!new_buffer.same_as(buffer)) {
+        any_change = true;
+      }
+      new_buffer_map.Set(param, new_buffer);
+    }
+
+    // Apply data variable and buffer substitutions to the prim_func body. These will have been
+    // accumulated from processing the parameters above.
+    Stmt new_body = VisitStmt(prim_func->body);
+    if (!new_body.same_as(prim_func->body)) {
+      any_change = true;
+    }
+
+    // We are done with the substitutions.
+    var_subst_.clear();
+    buffer_subst_.clear();
+
+    if (any_change) {
+      return PrimFunc(prim_func->params, std::move(new_body), prim_func->ret_type,
+                      std::move(new_buffer_map), prim_func->attrs, prim_func->span);
+    } else {
+      return prim_func;
+    }
+  }
+
+ private:
+  /*! \brief Replaces a var occurrence with whatever \p Subst maps it to. */
+  PrimExpr VisitExpr_(const VarNode* var_node) final {
+    Var substituted = Subst(var_node);
+    return substituted;
+  }
+
+  /*!
+   * \brief Visits the load's sub-expressions, then applies the data var substitution to its
+   * \p buffer_var.
+   */
+  PrimExpr VisitExpr_(const LoadNode* load_node) final {
+    Load new_load = Downcast<Load>(StmtExprMutator::VisitExpr_(load_node));
+    Var new_buffer_var = Subst(new_load->buffer_var.get());
+    if (!new_buffer_var.same_as(new_load->buffer_var)) {
+      // Rebuild from new_load (not load_node) so the already-visited index/predicate are kept.
+      return Load(new_load->dtype, new_buffer_var, new_load->index, new_load->predicate,
+                  new_load->span);
+    }
+    return new_load;
+  }
+
+  /*! \brief Visits the load's indices, then points it at the substituted buffer (if any). */
+  PrimExpr VisitExpr_(const BufferLoadNode* buffer_load_node) final {
+    auto visited = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(buffer_load_node));
+    Buffer replacement = Subst(visited->buffer.get());
+    if (replacement.same_as(visited->buffer)) {
+      return visited;
+    }
+    // Rebuild the load so it reads from the substituted buffer.
+    return BufferLoad(replacement, visited->indices, visited->span);
+  }
+
+  /*! \brief Currently only performs the default traversal of let statements. */
+  Stmt VisitStmt_(const LetStmtNode* let_stmt_node) final {
+    Stmt visited = StmtExprMutator::VisitStmt_(let_stmt_node);
+    // TODO(mbs): If the let-bound var is aliasing an existing buffer data var we need to
+    // rewrite it.
+    return visited;
+  }
+
+  /*!
+   * \brief Visits the attribute statement, then remaps its \p node through \p Subst when that
+   * node is a var.
+   */
+  Stmt VisitStmt_(const AttrStmtNode* attr_stmt_node) final {
+    AttrStmt new_attr_stmt = Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(attr_stmt_node));
+    // remap node if a var
+    if (const auto* var_node = new_attr_stmt->node.as<VarNode>()) {
+      Var new_var = Subst(var_node);
+      if (!new_var.same_as(new_attr_stmt->node)) {
+        // Carry the span across so source locations survive the rewrite.
+        return AttrStmt(new_var, new_attr_stmt->attr_key, new_attr_stmt->value,
+                        new_attr_stmt->body, new_attr_stmt->span);
+      }
+    }
+    return new_attr_stmt;
+  }
+
+  // ForNode default ok since loop_var never of PointerType
+
+  // WhileNode default ok
+
+  /*! \brief Default traversal only; no memory scope is imposed on the allocation here. */
+  Stmt VisitStmt_(const AllocateNode* allocate_node) final {
+    Stmt visited = StmtExprMutator::VisitStmt_(allocate_node);
+    // TODO(mbs): What memory scope should we assign to the new pointer?
+    return visited;
+  }
+
+  /*!
+   * \brief Visits the store's sub-expressions, then applies the data var substitution to its
+   * \p buffer_var.
+   */
+  Stmt VisitStmt_(const StoreNode* store_node) final {
+    Store new_store = Downcast<Store>(StmtExprMutator::VisitStmt_(store_node));
+    Var new_buffer_var = Subst(new_store->buffer_var.get());
+    if (!new_buffer_var.same_as(new_store->buffer_var)) {
+      // The rebuilt store must be returned, otherwise the substitution is silently lost.
+      return Store(new_buffer_var, new_store->value, new_store->index, new_store->predicate,
+                   new_store->span);
+    }
+    return new_store;
+  }
+
+  /*! \brief Visits the store's value and indices, then retargets it at the substituted buffer. */
+  Stmt VisitStmt_(const BufferStoreNode* buffer_store_node) final {
+    auto visited = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(buffer_store_node));
+    Buffer replacement = Subst(visited->buffer.get());
+    if (replacement.same_as(visited->buffer)) {
+      return visited;
+    }
+    // Rebuild the store so it writes into the substituted buffer.
+    return BufferStore(replacement, visited->value, visited->indices, visited->span);
+  }
+
+  /*! \brief Visits the realize statement, then points it at the substituted buffer (if any). */
+  Stmt VisitStmt_(const BufferRealizeNode* buffer_realize_node) final {
+    auto visited = Downcast<BufferRealize>(StmtExprMutator::VisitStmt_(buffer_realize_node));
+    Buffer replacement = Subst(visited->buffer.get());
+    if (replacement.same_as(visited->buffer)) {
+      return visited;
+    }
+    // Rebuild the realize so it refers to the substituted buffer.
+    return BufferRealize(replacement, visited->bounds, visited->condition, visited->body,
+                         visited->span);
+  }
+
+  // IfThenElseNode default ok
+  // AssertStmtNode default ok
+  // ProducerStoreNode default ok (though does not visit producer)
+  // ProducerRealizeNode default ok (though does not visit producer)
+
+  /*! \brief Visits the prefetch bounds, then points it at the substituted buffer (if any). */
+  Stmt VisitStmt_(const PrefetchNode* prefetch_node) final {
+    Prefetch new_prefetch = Downcast<Prefetch>(StmtExprMutator::VisitStmt_(prefetch_node));
+    Buffer new_buffer = Subst(new_prefetch->buffer.get());
+    if (!new_buffer.same_as(new_prefetch->buffer)) {
+      // Rebuild from new_prefetch (not prefetch_node) so the already-visited bounds are kept.
+      return Prefetch(new_buffer, new_prefetch->bounds, new_prefetch->span);
+    }
+    return new_prefetch;
+  }
+
+  // SeqStmtNode default ok
+  // EvaluateNode default ok
+
+  /*!
+   * \brief Applies the buffer substitution to a (referential) buffer region. Returns the
+   * original region object when no substitution applies.
+   */
+  BufferRegion VisitItem(const BufferRegionNode* buffer_region_node) {
+    Buffer substituted = Subst(buffer_region_node->buffer.get());
+    if (substituted.same_as(buffer_region_node->buffer)) {
+      return GetRef<BufferRegion>(buffer_region_node);
+    }
+    return BufferRegion(substituted, buffer_region_node->region);
+  }
+
+  /*!
+   * \brief Rewrites a match-buffer region: substitutes the referenced source buffer, and lets the
+   * aliasing (definitional) buffer inherit any constraint found on that source buffer.
+   * Returns the original object when nothing changed.
+   */
+  MatchBufferRegion VisitItem(const MatchBufferRegionNode* match_buffer_region_node) {
+    // 'source' holds a referential occurrence of a buffer: apply the buffer substitution there.
+    BufferRegion new_source = VisitItem(match_buffer_region_node->source.get());
+    // 'buffer' is a definitional occurrence aliased on top of the source, so transfer any
+    // memory scope from the source to it.
+    Optional<SEScope> opt_se_scope = GetBufferConstraint(new_source->buffer);
+    tir::Buffer new_buffer = match_buffer_region_node->buffer;
+    if (opt_se_scope.defined()) {
+      new_buffer = RewriteBuffer(match_buffer_region_node->buffer, opt_se_scope.value());
+    }
+    const bool unchanged = new_buffer.same_as(match_buffer_region_node->buffer) &&
+                           new_source.same_as(match_buffer_region_node->source);
+    if (unchanged) {
+      return GetRef<MatchBufferRegion>(match_buffer_region_node);
+    }
+    return MatchBufferRegion(new_buffer, new_source);
+  }
+
+  template <typename T>
+  Array<T> VisitItems(Array<T> items) {

Review comment:
       What's the diff between this method and Array's MutateByApply API?

##########
File path: python/tvm/tir/analysis/analysis.py
##########
@@ -196,3 +197,70 @@ def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
         Map from buffer to the LCA of all access to it.
     """
     return _ffi_api.detect_buffer_access_lca(func)  # type: ignore # pylint: disable=no-member
+
+
+# NOTE: relay_func_type in the following two functions should be relay.FuncType however that would
+# introduce a cyclic dependency. We make do with Object.
+
+
+def get_prim_func_arg_and_result_memory_constraints(
+    func: PrimFunc, relay_func_type: Object
+) -> List[AnyStr]:

Review comment:
       QQ: Why did we use List[AnyStr] instead of List[str]




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org