Posted to commits@tvm.apache.org by lu...@apache.org on 2023/09/28 13:54:33 UTC

[tvm] branch unity updated: [Unity] Implement relax.transform.KillAfterLastUse (#15810)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new aa4587feb5 [Unity] Implement relax.transform.KillAfterLastUse (#15810)
aa4587feb5 is described below

commit aa4587feb5103927d95e5e931149debd0a0aeafc
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Sep 28 08:54:22 2023 -0500

    [Unity] Implement relax.transform.KillAfterLastUse (#15810)
    
    * [Unity][Util] Expose CanonicalizeBindings as internal utility
    
    * [Unity] Implement relax.transform.KillAfterLastUse
    
    Prior to this commit, intermediate objects produced while executing a
    Relax function would persist until the end of the Relax function.
    While re-use of static allocations is handled by the
    `StaticPlanBlockMemory` transform, re-use of dynamic allocations is
    handled by the `relax_vm::PooledAllocator`.  For large Relax functions
    representing end-to-end model execution, releasing memory from the VM
    registers to the `relax_vm::PooledAllocator` at the end of the
    function call may be insufficient.
    
    This commit introduces a new pass, `relax.transform.KillAfterLastUse`,
    which identifies the last usage of each Relax variable and inserts a
    `relax.memory.kill_tensor`, `relax.memory.kill_storage`, or
    `relax.vm.kill_object` call depending on the object type.  This
    insertion is suppressed if a Relax variable has already been killed,
    as is the case for static allocations and tensors tracked by
    `StaticPlanBlockMemory`.
    
    * Avoid calling R.vm.kill_object on objects not in registers
---
 python/tvm/relax/transform/transform.py        |  10 +
 python/tvm/relax/vm_build.py                   |   1 +
 src/relax/transform/kill_after_last_use.cc     | 289 +++++++++++++++++++++++++
 src/relax/transform/utils.h                    |  13 ++
 tests/python/relax/test_kill_after_last_use.py |  55 +++++
 5 files changed, 368 insertions(+)
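
For a quick sense of what the pass does, here is a minimal sketch of applying
it by hand. This is illustrative only and not part of the commit; the toy
module below is closely modeled on the new test added at the end of this
patch.

    import tvm
    from tvm.script import ir as I, relax as R
    from tvm.relax.transform import KillAfterLastUse

    @I.ir_module
    class ToyModule:
        @R.function(pure=False)
        def main(x: R.Tensor([16, 32], "float32")):
            # Dynamically allocated intermediates: a storage object and a
            # tensor view into it.
            storage = R.memory.alloc_storage(R.shape([2048]), 0, "global", "uint8")
            y = R.memory.alloc_tensor(storage, 0, R.shape([16, 32]), "float32")
            z = R.add(x, y)
            return z

    # Inserts R.memory.kill_storage(storage) after the alloc_tensor (its last
    # use) and R.memory.kill_tensor(y) after the add; the function output z
    # is left untouched.
    After = KillAfterLastUse()(ToyModule)
    print(After.script())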

diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 13874aa044..2a06d5098e 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -368,6 +368,16 @@ def StaticPlanBlockMemory() -> tvm.ir.transform.Pass:
     return _ffi_api.StaticPlanBlockMemory()  # type: ignore
 
 
+def KillAfterLastUse() -> tvm.ir.transform.Pass:
+    """Drop all tensor/storage objects after last use
+
+    Returns
+    -------
+    ret : tvm.ir.transform.Pass
+    """
+    return _ffi_api.KillAfterLastUse()  # type: ignore
+
+
 def VMBuiltinLower() -> tvm.ir.transform.Pass:
     """Lowering generic intrinsic to VM intrinsics.
 
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index d8679522db..142da5c451 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -310,6 +310,7 @@ def build(
     passes.append(relax.transform.RemovePurityChecking())
     passes.append(relax.transform.CallTIRRewrite())
     passes.append(relax.transform.StaticPlanBlockMemory())
+    passes.append(relax.transform.KillAfterLastUse())
 
     if tvm.transform.PassContext.current().config.get("relax.backend.use_cuda_graph", False):
         passes.append(relax.transform.RewriteCUDAGraph())
diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc
new file mode 100644
index 0000000000..0f28c6c2b9
--- /dev/null
+++ b/src/relax/transform/kill_after_last_use.cc
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/kill_after_last_use.cc
+ * \brief Kill storage/tensor objects after last use, if not already killed
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/nested_msg.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <map>
+#include <set>
+#include <vector>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+class UnusedTrivialBindingRemover : public ExprMutator {
+ public:
+  static Expr Apply(Expr expr) {
+    struct UsedCollector : ExprVisitor {
+      void VisitExpr_(const VarNode* val) override { used.insert(val); }
+      void VisitExpr_(const DataflowVarNode* val) override {
+        VisitExpr_(static_cast<const VarNode*>(val));
+      }
+
+      void VisitBinding_(const VarBindingNode* binding, const VarNode* val) override {
+        has_trivial_binding.insert(binding->var.get());
+        ExprVisitor::VisitBinding_(binding, val);
+      }
+      void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* val) override {
+        VisitBinding_(binding, static_cast<const VarNode*>(val));
+      }
+
+      std::unordered_set<const VarNode*> used;
+      std::unordered_set<const VarNode*> has_trivial_binding;
+    };
+
+    UsedCollector collector;
+    collector(expr);
+
+    auto to_remove = std::move(collector.has_trivial_binding);
+    for (const auto& used : collector.used) {
+      to_remove.erase(used);
+    }
+
+    UnusedTrivialBindingRemover remover(to_remove);
+    return remover(expr);
+  }
+
+ private:
+  explicit UnusedTrivialBindingRemover(std::unordered_set<const VarNode*> to_remove)
+      : to_remove_(std::move(to_remove)) {}
+
+  void VisitBinding(const Binding& binding) override {
+    if (!to_remove_.count(binding->var.get())) {
+      ExprMutator::VisitBinding(binding);
+    }
+  }
+
+  std::unordered_set<const VarNode*> to_remove_;
+};
+
+class CollectLastUsage : public ExprVisitor {
+ public:
+  struct LastUsage {
+    std::vector<const VarNode*> tensors;
+    std::vector<const VarNode*> storage;
+    std::vector<const VarNode*> objects;
+  };
+  using Result = std::unordered_map<const VarNode*, LastUsage>;
+
+  static Result Collect(const Expr& expr) {
+    CollectLastUsage visitor;
+    visitor(expr);
+
+    Result output;
+    for (const auto* var : visitor.binding_order_) {
+      if (auto it = visitor.last_usage_of_.find(var); it != visitor.last_usage_of_.end()) {
+        const auto* last_usage_point = it->second;
+        bool is_output = last_usage_point == nullptr;
+        bool already_killed = visitor.killed_objects_.count(var);
+
+        // Currently, the VM requires that objects to be killed
+        // exist only in VM registers.  This requires
+        // KillAfterLastUse to have more knowledge about the VM
+        // implementation than should exist at this stage of lowering.
+        // In the future, this may be handled more easily at the
+        // CodeGenVM level.
+        bool stored_in_vm_register =
+            !(visitor.constant_tensors_.count(var) || var->struct_info_.as<FuncStructInfoNode>() ||
+              var->struct_info_.as<ShapeStructInfoNode>() ||
+              var->struct_info_.as<PrimStructInfoNode>());
+
+        if (!is_output && !already_killed) {
+          if (visitor.storage_objects_.count(var)) {
+            output[last_usage_point].storage.push_back(var);
+          } else if (var->struct_info_.as<TensorStructInfoNode>() && stored_in_vm_register) {
+            output[last_usage_point].tensors.push_back(var);
+          } else if (stored_in_vm_register) {
+            output[last_usage_point].objects.push_back(var);
+          }
+        }
+      }
+    }
+
+    return output;
+  }
+
+  void VisitBinding(const Binding& binding) override {
+    auto cache = current_binding_;
+    current_binding_ = binding->var.get();
+    binding_order_.push_back(current_binding_);
+    ExprVisitor::VisitBinding(binding);
+    current_binding_ = cache;
+  }
+
+  void VisitExpr_(const VarNode* op) override {
+    ExprVisitor::VisitExpr_(op);
+    // Overwrite any previous usage, such that after the visitor
+    // completes, last_usage_of_ contains the last usage point.  If
+    // this occurs in an output, then current_binding_ will be
+    // nullptr.
+    last_usage_of_[UnwrapTrivialBindings(op)] = current_binding_;
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* val) override {
+    static const Op& vm_alloc_storage = Op::Get("relax.vm.alloc_storage");
+    static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage");
+
+    static const Op& mem_kill_tensor = Op::Get("relax.memory.kill_tensor");
+    static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage");
+    static const Op& vm_kill_object = Op::Get("relax.vm.kill_object");
+
+    if (val->op.same_as(vm_alloc_storage) || val->op.same_as(mem_alloc_storage)) {
+      storage_objects_.insert(binding->var.get());
+    } else if (val->op.same_as(mem_kill_tensor) || val->op.same_as(mem_kill_storage) ||
+               val->op.same_as(vm_kill_object)) {
+      CHECK_EQ(val->args.size(), 1)
+          << "Operator " << val->op << " should have one argument, "
+          << "but instead found " << val->args.size() << " arguments: " << val->args;
+      auto killed_object = val->args[0].as<VarNode>();
+      ICHECK(killed_object) << "Internal error: non-normalized expression " << GetRef<Call>(val);
+      killed_objects_.insert(UnwrapTrivialBindings(killed_object));
+    } else {
+      // Only recursively visit if it isn't one of the special cases.
+      ExprVisitor::VisitBinding_(binding, val);
+    }
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const VarNode* val) override {
+    // Because the VM re-uses the same register for variable
+    // re-binding, we need to de-duplicate across trivial bindings in
+    // order to avoid calling `vm.kill_object` multiple times on the
+    // same register.  In the future, this can be simplified by
+    // replacing the de-duplication in CodeGenVM with a call to
+    // CanonicalizeBindings.
+    trivial_bindings_.insert({binding->var.get(), UnwrapTrivialBindings(val)});
+
+    // Do not call ExprVisitor::VisitBinding_ here, as the trivial
+    // rebinding should not be treated as a point of use.
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) override {
+    constant_tensors_.insert(binding->var.get());
+  }
+
+ private:
+  const VarNode* UnwrapTrivialBindings(const VarNode* var) const {
+    while (true) {
+      if (auto it = trivial_bindings_.find(var); it != trivial_bindings_.end()) {
+        var = it->second;
+      } else {
+        return var;
+      }
+    }
+  }
+
+  // The current binding being visited, or nullptr if no binding is
+  // being visited.
+  const VarNode* current_binding_{nullptr};
+
+  // Order of bindings, to ensure consistent order of destruction, in
+  // case a Binding is the last usage for more than one variable.
+  std::vector<const VarNode*> binding_order_;
+
+  // Map from a variable to the last variable binding that makes use
+  // of it.
+  std::unordered_map<const VarNode*, const VarNode*> last_usage_of_;
+
+  // Storage objects, eligible for R.memory.kill_storage.  This cannot
+  // be determined solely from the StructInfo, because the
+  // `R.*.alloc_storage` operators return ObjectStructInfo.
+  std::unordered_set<const VarNode*> storage_objects_;
+
+  // Constants, which do not have a VM register and must *not* have
+  // R.memory.kill_tensor called on them.
+  std::unordered_set<const VarNode*> constant_tensors_;
+
+  // Set of objects that already have a call node to kill them; used to avoid duplicate kills.
+  std::unordered_set<const VarNode*> killed_objects_;
+
+  // Trivial var-to-var bindings.
+  std::unordered_map<const VarNode*, const VarNode*> trivial_bindings_;
+};
+
+class KillInserter : public ExprMutator {
+ private:
+  Expr VisitExpr_(const FunctionNode* op) override {
+    last_usage_ = CollectLastUsage::Collect(GetRef<Expr>(op));
+    auto mutated = ExprMutator::VisitExpr_(op);
+    last_usage_.clear();
+    return mutated;
+  }
+
+  Expr VisitExpr_(const SeqExprNode* op) override {
+    last_usage_ = CollectLastUsage::Collect(GetRef<Expr>(op));
+    auto mutated = ExprMutator::VisitExpr_(op);
+    last_usage_.clear();
+    return mutated;
+  }
+
+  void VisitBinding(const Binding& binding) override {
+    ExprMutator::VisitBinding(binding);
+    if (auto it = last_usage_.find(binding->var.get()); it != last_usage_.end()) {
+      static const Op& mem_kill_tensor = Op::Get("relax.memory.kill_tensor");
+      for (const auto& tensor_obj : it->second.tensors) {
+        builder_->Emit(Call(mem_kill_tensor, {GetRef<Expr>(tensor_obj)}), /*name_hint=*/"_");
+      }
+
+      static const Op& mem_kill_storage = Op::Get("relax.memory.kill_storage");
+      for (const VarNode* storage_obj : it->second.storage) {
+        builder_->Emit(Call(mem_kill_storage, {GetRef<Expr>(storage_obj)}), /*name_hint=*/"_");
+      }
+
+      static const Op& vm_kill_object = Op::Get("relax.vm.kill_object");
+      for (const VarNode* obj : it->second.objects) {
+        builder_->Emit(Call(vm_kill_object, {GetRef<Expr>(obj)}), /*name_hint=*/"_");
+      }
+    }
+  }
+
+  CollectLastUsage::Result last_usage_;
+};
+
+Expr KillAfterLastUse(Expr expr) {
+  expr = CanonicalizeBindings(expr);
+  expr = UnusedTrivialBindingRemover::Apply(expr);
+
+  KillInserter mutator;
+  return mutator(expr);
+}
+
+namespace transform {
+
+Pass KillAfterLastUse() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function func, IRModule m, PassContext pc) {
+        return Downcast<Function>(relax::KillAfterLastUse(std::move(func)));
+      };
+  return CreateFunctionPass(pass_func, /*opt_level=*/0, "KillAfterLastUse", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.KillAfterLastUse").set_body_typed(KillAfterLastUse);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
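
To make the special cases above concrete (constants, shapes, and trivial
rebindings), here is a small annotated module. This is illustrative only and
not part of the commit; the annotations follow the logic of CollectLastUsage
rather than the output of an actual run.

    import tvm
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class SpecialCases:
        @R.function
        def main(x: R.Tensor([4], "float32")):
            c = R.const(1.0, "float32")  # constant: no VM register, never killed
            s = R.shape([4])             # ShapeStructInfo: not a killable VM object
            y = R.add(x, c)              # tensor held in a VM register
            alias = y                    # trivial rebinding: shares y's register
            z = R.multiply(alias, y)     # last use of y; one kill inserted after this binding
            return z                     # function output: never killed
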
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 6e44f07aa6..78e5c31c75 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -388,6 +388,19 @@ inline String GetCodegenName(const std::string& composite_name) {
  */
 Expr EliminateCommonSubexpr(const Expr& expr, bool call_only = false);
 
+/*! \brief Remove use of trivial bindings
+ *
+ * Utility for simplifying relax expressions by folding var bindings
+ * and match shape nodes.  May include other forms of simplification
+ * in the future.  Ideally should be used before constant folding and
+ * eliminating unused bindings.
+ *
+ * \param expr The expression to be canonicalized
+ *
+ * \return The canonicalized expression
+ */
+Expr CanonicalizeBindings(const Expr& expr);
+
 }  // namespace relax
 }  // namespace tvm
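
The CanonicalizeBindings transformation exposed above already exists as a
user-facing pass; the new declaration makes the same rewrite callable from
other C++ passes such as KillAfterLastUse. As a rough illustration of what it
folds (using the public Python pass rather than the new C++ entry point; not
part of this commit):

    import tvm
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class TrivialRebinding:
        @R.function
        def main(x: R.Tensor([16], "float32")):
            y = R.add(x, x)
            z = y                 # trivial var-to-var rebinding
            w = R.multiply(z, z)
            return w

    # Rewrites uses of `z` to refer to `y` directly, which is what lets
    # KillAfterLastUse associate a single kill with each underlying object.
    After = tvm.relax.transform.CanonicalizeBindings()(TrivialRebinding)
    print(After.script())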
 
diff --git a/tests/python/relax/test_kill_after_last_use.py b/tests/python/relax/test_kill_after_last_use.py
new file mode 100644
index 0000000000..eb6e0777ae
--- /dev/null
+++ b/tests/python/relax/test_kill_after_last_use.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.relax
+import tvm.testing
+
+from tvm.script import ir as I, relax as R
+
+from tvm.relax.transform import KillAfterLastUse
+
+
+def test_basic():
+    @I.ir_module
+    class Before:
+        @R.function(pure=False)
+        def main(x: R.Tensor([16, 32], "float32")):
+            storage = R.memory.alloc_storage(R.shape([2048]), 0, "global", "uint8")
+            y = R.memory.alloc_tensor(storage, 0, R.shape([16, 32]), "float32")
+            _dummy = R.call_packed("add_tensors", [x, y], sinfo_args=(R.Tuple,))
+            z = R.add(x, y)
+            return z
+
+    @I.ir_module
+    class Expected:
+        @R.function(pure=False)
+        def main(x: R.Tensor([16, 32], "float32")):
+            storage = R.memory.alloc_storage(R.shape([2048]), 0, "global", "uint8")
+            y = R.memory.alloc_tensor(storage, 0, R.shape([16, 32]), "float32")
+            _ = R.memory.kill_storage(storage)
+            _dummy = R.call_packed("add_tensors", [x, y], sinfo_args=(R.Tuple,))
+            z = R.add(x, y)
+            _ = R.memory.kill_tensor(y)
+            return z
+
+    After = KillAfterLastUse()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()