Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/05/04 10:29:45 UTC

[GitHub] [tvm] lhutton1 opened a new pull request, #11208: [AOT] Calculate used memory at the callsite of primitive functions

lhutton1 opened a new pull request, #11208:
URL: https://github.com/apache/tvm/pull/11208

   This PR introduces a new pass in the AOT executor, "AnnotateUsedMemory", which applies liveness analysis at the callsite of each primitive function to calculate the total size of the tensors that are live at that point of execution. The result is attached to the function as a "used_memory" annotation, which later stages of the compiler (e.g. external codegens) can consume to learn how much memory is in use when the function is called. This can be useful for some optimizations.
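To make the idea concrete, here is a minimal Python sketch of callsite liveness on a straight-line let chain. It is a toy model, not TVM's implementation: the binding format, `annotate_used_memory` and the byte sizes are hypothetical. It only mimics what the quoted `PostVisitLet_` does: merge the live-in and live-out sets at a call, drop the var naming the primitive function, and sum the remaining tensor sizes.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical toy IR for a straight-line ANF let chain. Each binding is
# (var, kind, operands):
#   kind == "func": var is bound to a primitive function (no operands)
#   kind == "call": var = operands[0](*operands[1:]), operands[0] being a function var
Binding = Tuple[str, str, List[str]]


def annotate_used_memory(params: List[str], bindings: List[Binding], result: str,
                         sizes: Dict[str, int]) -> Tuple[Dict[str, int], int]:
    """Return ({callee var: used_memory}, io_used_memory) for the toy program."""
    n = len(bindings)
    live_in: List[Set[str]] = [set() for _ in range(n)]
    live_out: List[Set[str]] = [set() for _ in range(n)]

    # Backward liveness: only the returned var is live after the last binding.
    live: Set[str] = {result}
    for i in range(n - 1, -1, -1):
        var, _, operands = bindings[i]
        live_out[i] = set(live)
        live = (live - {var}) | set(operands)
        live_in[i] = set(live)

    used_memory: Dict[str, int] = {}
    for i, (_, kind, operands) in enumerate(bindings):
        if kind != "call":
            continue
        callee = operands[0]
        # Merge live-in and live-out, then drop the var naming the primitive function.
        live_tensors = (live_in[i] | live_out[i]) - {callee}
        # Vars without an entry in `sizes` (e.g. other function vars) count as 0 here.
        used_memory[callee] = sum(sizes.get(v, 0) for v in live_tensors)

    # IO memory: all parameters plus the returned tensor.
    io_used_memory = sum(sizes[p] for p in params) + sizes[result]
    return used_memory, io_used_memory


# The doc-comment example from annotate_used_memory.cc: (1, 2, 2, 4) int8 is 16 bytes.
used, io = annotate_used_memory(
    params=["input"],
    bindings=[("x_0", "func", []), ("x_1", "call", ["x_0", "input"])],
    result="x_1",
    sizes={"input": 16, "x_1": 16},
)
print(used, io)  # {'x_0': 32} 32
```

With two chained calls, a tensor that is dead after the first call (such as the input) is no longer counted at the second callsite, which is the point of using liveness rather than summing every tensor in the function.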
   
   Note: this PR depends on #11091, so it also includes the contents of that PR.
   
   cc @manupa-arm @ekalda @NicolaLancellotti 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] areusch commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
areusch commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r879752148


##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -1063,6 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     }
 
     mod = transform::ToANormalForm()(mod);
+    mod = transform::InferType()(mod);
+    mod = transform::AnnotateUsedMemory()(mod);

Review Comment:
   got it, so is this intended as cascader input to tune the down-selection process?





[GitHub] [tvm] lhutton1 commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r868110489


##########
src/relay/backend/manifest_lifetimes.cc:
##########
@@ -0,0 +1,367 @@
+/*

Review Comment:
   No functional changes here; this simply moves `manifest_lifetimes.cc` up one directory (`../`, outside the scope of the VM) and splits it into `.cc`/`.h` files.





[GitHub] [tvm] lhutton1 commented on pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on PR #11208:
URL: https://github.com/apache/tvm/pull/11208#issuecomment-1279165680

   Hi @zhaoyang-star, thanks for taking a look, it's great to see this pass being used elsewhere. The pass currently expects its input to be a module of primitive functions, so I would suggest running `AnnotateUsedMemory` after `FuseOps`, similar to:
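   ```
   from tvm import relay

   # `mod` is the user's IRModule; `AnnotateUsedMemory` is assumed to be
   # a wrapper around the registered pass that is already in scope.
   mod = relay.transform.InferType()(mod)
   mod = relay.transform.FuseOps()(mod)
   mod = relay.transform.InferType()(mod)
   mod = relay.transform.ToANormalForm()(mod)
   mod = relay.transform.InferType()(mod)
   mod = AnnotateUsedMemory()(mod)
   ```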
   
   I tried running your example locally with the above change, and it produced the expected `used_memory` annotations. However, there appears to be an issue when building the module after the `AnnotateUsedMemory` pass has run. Without digging too deeply, I suspect this is because the pass was designed only for the AOT executor, not the graph executor. I believe changes similar to #11091 would be needed in the graph executor for it to support A-normal form. Hope this helps :)
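For readers less familiar with A-normal form: it binds every intermediate result to its own `let` variable, which is what allows a liveness analysis to ask which tensors are alive at a given callsite. The toy sketch below (hypothetical names; not TVM's `ToANormalForm`) flattens a nested call tree into such a let chain.

```python
from typing import List, Tuple, Union

# Hypothetical nested expression: a var name (str) or a tuple (callee, arg, ...).
Expr = Union[str, tuple]
Binding = Tuple[str, str, List[str]]  # (let var, callee, argument vars)


def to_anf(expr: Expr) -> Tuple[List[Binding], str]:
    """Flatten a nested call tree into a let chain; return (bindings, result var)."""
    bindings: List[Binding] = []

    def visit(e: Expr) -> str:
        if isinstance(e, str):
            return e  # already a variable, nothing to bind
        callee, *args = e
        arg_vars = [visit(a) for a in args]  # bind arguments first, left to right
        name = f"x_{len(bindings)}"
        bindings.append((name, callee, arg_vars))
        return name

    result = visit(expr)
    return bindings, result


bindings, result = to_anf(("f1", ("f0", "input"), "weights"))
print(bindings)  # [('x_0', 'f0', ['input']), ('x_1', 'f1', ['x_0', 'weights'])]
print(result)    # x_1
```

Binding arguments left to right is just one valid choice; other orders produce equally valid let chains.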




[GitHub] [tvm] lhutton1 commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r885076297


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness
+ * of the input/output tensors at each function callsite and calculating the total amount of
+ * memory these tensors require. This is added as a "used_memory" annotation to the function
+ * in question. In addition, the containing function is annotated with an "io_used_memory"
+ * annotation which refers to the total memory required for the IO tensors.
+ *
+ * A simple example:
+ *
+ * Before:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1 = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * After:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32) -> Tensor[(1, 2, 2,
+ * 4), int8] { nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * Note that in the simple example above io_used_memory and used_memory are the same since there
+ * is only one primitive function.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
+   * added to the input function which refers to the total size required for the IO
+   * tensors.
+   */
+  Function operator()(const Function& func) {
+    uint64_t io_used_memory = 0;
+
+    // Inputs
+    for (const Var& param : func->params) {
+      Type type = param->checked_type();
+      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+      io_used_memory += CalculateRelayExprSizeBytes(type);
+    }
+
+    // Outputs
+    Type type = func->body->checked_type();
+    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+    io_used_memory += CalculateRelayExprSizeBytes(type);
+
+    Expr new_func_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, func->params, new_func_body);
+    return WithAttr(std::move(new_func), "io_used_memory",
+                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
+  }
+
+  /*!
+   * \brief Establish which let bindings have primitive function values.
+   */
+  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) {
+    if (const auto* func_node = value.as<FunctionNode>()) {
+      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
+          << "Expect top-level functions to be primitive.";
+      let_bound_prim_func_.insert(var);
+    }
+    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
+  }
+
+  /*!
+   * \brief Visit let nodes and perform one of two actions depending on their value:
+   *
+   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
+   *               primitive functions.
+   *
+   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
+   *                   previous analysis at the callsite.
+   *
+   */
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    Var let_var = post_let_node->var;
+    Expr let_value = IgnoreOnDevice(post_let_node->value);
+
+    if (let_value->IsInstance<CallNode>()) {
+      Call callsite = Downcast<Call>(let_value);
+      if (CheckPrimitiveFunctionCall(callsite)) {
+        Var call_op = Downcast<Var>(callsite->op);
+
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // in and out varset's and then removing the var that references the primitive
+        // function itself since we don't want this included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);
+        const transform::VarSet& live_out = liveness_.live_out.at(cfg_node);
+        live_tensors.insert(live_out.begin(), live_out.end());
+        live_tensors.erase(call_op);
+
+        // Calculate size of live tensors and store to allow annotation when the function
+        // gets visited.
+        uint64_t used_memory = 0;
+        for (const auto& var : live_tensors) {
+          Type type = var->checked_type();
+          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+          used_memory += CalculateRelayExprSizeBytes(type);
+        }
+        used_memory_annotations_[call_op] = used_memory;

Review Comment:
   Thanks for the discussion. Yes, a list sounds like a good idea to me, so that no information is lost (hopefully it's useful to someone in the future).
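A Python sketch of the difference being discussed (hypothetical names and values): if the same primitive function var is reached from more than one callsite, a single map entry, as in the quoted `used_memory_annotations_[call_op] = used_memory;`, keeps only the last value written, whereas a list keeps one value per callsite.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def overwrite_per_callee(callsites: List[Tuple[str, int]]) -> Dict[str, int]:
    """Mirrors a map assignment: a later callsite replaces an earlier value."""
    result: Dict[str, int] = {}
    for callee, used in callsites:
        result[callee] = used
    return result


def list_per_callee(callsites: List[Tuple[str, int]]) -> Dict[str, List[int]]:
    """Keeps every callsite's value, in the order the callsites were visited."""
    result: Dict[str, List[int]] = defaultdict(list)
    for callee, used in callsites:
        result[callee].append(used)
    return dict(result)


calls = [("f0", 48), ("f1", 40), ("f0", 24)]  # hypothetical: f0 is called twice
print(overwrite_per_callee(calls))  # {'f0': 24, 'f1': 40}
print(list_per_callee(calls))       # {'f0': [48, 24], 'f1': [40]}
```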





[GitHub] [tvm] lhutton1 commented on pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on PR #11208:
URL: https://github.com/apache/tvm/pull/11208#issuecomment-1120909910

   also cc @mbs-octoml @areusch 




[GitHub] [tvm] manupa-arm commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r883978886


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+        // Calculate size of live tensors and store to allow annotation when the function
+        // gets visited.
+        uint64_t used_memory = 0;
+        for (const auto& var : live_tensors) {
+          Type type = var->checked_type();
+          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+          used_memory += CalculateRelayExprSizeBytes(type);
+        }
+        used_memory_annotations_[call_op] = used_memory;

Review Comment:
   A list seems fine as well; for our use case we can consume it by taking the max.
   
   @lhutton1 WDYT?
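If the annotation becomes a list, a consumer that only needs an upper bound could reduce it as suggested here. A minimal Python sketch with hypothetical names and values:

```python
from typing import Dict, List


def worst_case_per_function(annotations: Dict[str, List[int]]) -> Dict[str, int]:
    """Collapse each function's per-callsite values to the largest one."""
    return {name: max(values) for name, values in annotations.items() if values}


annotations = {"f0": [48, 24], "f1": [40]}  # hypothetical list-valued annotations
per_function = worst_case_per_function(annotations)
print(per_function)                # {'f0': 48, 'f1': 40}
print(max(per_function.values()))  # 48
```

Taking the max keeps the value conservative: it is never smaller than the memory in use at any individual callsite of that function.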





[GitHub] [tvm] altanh commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
altanh commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r886196916


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // in and out varset's and then removing the var that references the primitive
+        // function itself since we don't want this included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);

Review Comment:
   Kind of. I think you've got my point: my comment was mainly about how the same function might get different used-memory annotations depending on the incoming ANF ordering. In the unit test you only rely on the ordering to match up the used-memory values (which will be correct), but it could get tricky if we expect the same function (say, referenced by name or ID) to have the same annotation across runs of the pass (*if* ANF ordering is non-deterministic). Let me know if that makes sense! This might not be a problem for your use case anyway.
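A toy Python model of this concern (hypothetical names and sizes, not the pass itself): the same dataflow graph `c = f_c(f_a(input), f_b(input))` is linearized in two valid orders. The liveness-based used memory of `f_a` and `f_b` changes with the order, while `f_c` does not.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical toy model: each call is (result var, callee name, argument vars),
# listed in the order an ANF pass happened to emit them.
Call = Tuple[str, str, List[str]]


def used_memory_per_callee(calls: List[Call], result: str,
                           sizes: Dict[str, int]) -> Dict[str, int]:
    live: Set[str] = {result}  # live after the last call
    used: Dict[str, int] = {}
    for var, callee, args in reversed(calls):
        live_out = set(live)
        live = (live - {var}) | set(args)  # live-in of this call
        # Tensors live at the callsite: union of live-in and live-out.
        used[callee] = sum(sizes[v] for v in live | live_out)
    return used


sizes = {"input": 16, "a": 32, "b": 8, "c": 4}
order_1 = [("a", "f_a", ["input"]), ("b", "f_b", ["input"]), ("c", "f_c", ["a", "b"])]
order_2 = [("b", "f_b", ["input"]), ("a", "f_a", ["input"]), ("c", "f_c", ["a", "b"])]
print(used_memory_per_callee(order_1, "c", sizes))  # {'f_c': 44, 'f_b': 56, 'f_a': 48}
print(used_memory_per_callee(order_2, "c", sizes))  # {'f_c': 44, 'f_a': 56, 'f_b': 24}
```

So unless the ANF ordering is deterministic, matching annotations by function name across runs could observe different values for the same function.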





[GitHub] [tvm] lhutton1 commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r878083413


##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+/*!
+ * \file src/relay/backend/aot/annotate_used_memory.cc
+ * \brief Analyzes the memory pressure at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/transform.h>
+
+#include "../../transforms/device_aware_visitors.h"
+#include "../manifest_lifetimes.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analysing the liveness
+ * of the input/output tensors at the function callsite and calculating the total amount of
+ * memory these tensors require.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Get the memory required for a primitive Relay function by calculating the total
+   * bytes of the live tensors at the callsite of the function.
+   *
+   * \param live_tensors The tensors that are live when the function is called.
+   * \return int The total number of bytes a function requires.
+   */
+  int GetMemoryUsage(const transform::VarSet& live_tensors) {
+    Array<Type> types_stack = {};
+    int memory_usage = 0;
+
+    for (const Var& var : live_tensors) {
+      Type var_type = var->checked_type();
+      ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass.";
+      types_stack.push_back(var_type);
+    }
+
+    while (!types_stack.empty()) {

Review Comment:
   With @areusch's suggestion below, I should be able to remove this loop altogether.
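The stack-based walk in `GetMemoryUsage` quoted in this thread sums tensor bytes over tuple fields and function return types. Below is a recursive Python sketch of the same rules, using hypothetical stand-in type classes; it is not TVM's `CalculateRelayExprSizeBytes`.

```python
import math
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class TensorType:  # hypothetical stand-in for a statically shaped Relay TensorType
    shape: Tuple[int, ...]
    dtype_bytes: int


@dataclass(frozen=True)
class TupleType:  # hypothetical stand-in for a Relay TupleType
    fields: Tuple["Type", ...]


@dataclass(frozen=True)
class FuncType:  # hypothetical stand-in; only the return type is sized, as in the quoted loop
    ret_type: "Type"


Type = Union[TensorType, TupleType, FuncType]


def size_bytes(t: Type) -> int:
    """Total bytes of all tensors reachable from `t`."""
    if isinstance(t, TupleType):
        return sum(size_bytes(f) for f in t.fields)
    if isinstance(t, FuncType):
        return size_bytes(t.ret_type)
    if isinstance(t, TensorType):
        return math.prod(t.shape) * t.dtype_bytes  # element count * bytes per element
    raise TypeError(f"unexpected type {t!r}")


print(size_bytes(TensorType((1, 2, 2, 4), 1)))  # 16
```

Python integers do not overflow, whereas the quoted C++ accumulates byte counts in `int`; the later revision of the pass quoted earlier in this thread accumulates into `uint64_t`.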



##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+    while (!types_stack.empty()) {
+      Type current_type = types_stack.back();
+      types_stack.pop_back();
+
+      if (const auto* tt_node = current_type.as<TupleTypeNode>()) {
+        for (const Type& type : tt_node->fields) {
+          types_stack.push_back(type);
+        }
+        continue;
+      } else if (const auto* ft_node = current_type.as<FuncTypeNode>()) {
+        types_stack.push_back(ft_node->ret_type);
+        continue;
+      }
+
+      const auto* tt_node = current_type.as<TensorTypeNode>();
+      ICHECK(tt_node) << "Expected TensorTypeNode but was " << current_type->GetTypeKey();
+      int total_tensor_bytes = GetTensorBytes(tt_node);
+      memory_usage += total_tensor_bytes;
+    }
+    return memory_usage;
+  }
+
+  /*!
+   * \brief Get the number of bytes a tensor requires.
+   *
+   * \param tensor_type_node The checked type of the tensor.
+   * \return int The number of bytes required.
+   */
+  int GetTensorBytes(const TensorTypeNode* tensor_type_node) {
+    PrimExpr size = tensor_type_node->Size();
+    const auto* size_int_imm = size.as<IntImmNode>();
+    ICHECK(size_int_imm) << "Expected tensor size to be an IntImmNode but was "
+                         << size->GetTypeKey();
+
+    int total_size = size_int_imm->value;
+    int dtype_bytes = tensor_type_node->dtype.bytes();
+    return total_size * dtype_bytes;
+  }
+
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    if (const auto* func_node = pre_let_node->value.as<FunctionNode>()) {

Review Comment:
   Great point, thanks. In hindsight, we should definitely be visiting the callsite rather than the function. Previously I thought I could get all the information I needed just by visiting the function, since checking whether the call node's op is a function is a little more complex when let bindings get in the way. But as you both rightly mention, this isn't the correct way to do it.
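To make the callsite-based approach concrete, here is a minimal sketch over a toy, string-based stand-in for ANF let bindings. The `Binding` struct and `FindPrimitiveCallsites` are hypothetical and are not TVM API. The idea matches the later revision's `PreVisitLetBinding_`: remember which let-bound vars hold primitive functions, then treat a call whose op is one of those vars as a primitive callsite. Assuming ANF binds a function before any call to it, one forward pass is enough.

```cpp
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for one ANF let binding (not TVM's Let/Var API).
struct Binding {
  std::string var;             // let-bound variable, e.g. "%x_0"
  bool is_primitive_function;  // value is `fn (...) { ... }` annotated with Primitive=1
  std::string call_op;         // non-empty if the value is a call, naming its op
};

// Returns the vars bound to calls whose op is a let-bound primitive function.
std::vector<std::string> FindPrimitiveCallsites(const std::vector<Binding>& bindings) {
  std::set<std::string> prim_funcs;  // vars known to hold primitive functions
  std::vector<std::string> callsites;
  for (const Binding& b : bindings) {
    if (b.is_primitive_function) {
      prim_funcs.insert(b.var);
    } else if (!b.call_op.empty() && prim_funcs.count(b.call_op) > 0) {
      callsites.push_back(b.var);
    }
  }
  return callsites;
}
```

For `let %x_0 = fn (...); let %x_1 = %x_0(%input); let %x_2 = nn.relu(...)`, only `%x_1` is reported, because `nn.relu` is an operator rather than a let-bound primitive function.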



##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/annotate_used_memory.cc
+ * \brief Analyzes the memory pressure at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/transform.h>
+
+#include "../../transforms/device_aware_visitors.h"
+#include "../manifest_lifetimes.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analysing the liveness
+ * of the input/output tensors at the function callsite and calculating the total amount of
+ * memory these tensors require.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Get the memory required for a primitive Relay function by calculating the total
+   * bytes of the live tensors at the callsite of the function.
+   *
+   * \param live_tensors The tensors that are live when the function is called.
+   * \return int The total number of bytes a function requires.
+   */
+  int GetMemoryUsage(const transform::VarSet& live_tensors) {
+    Array<Type> types_stack = {};
+    int memory_usage = 0;
+
+    for (const Var& var : live_tensors) {
+      Type var_type = var->checked_type();
+      ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass.";
+      types_stack.push_back(var_type);
+    }
+
+    while (!types_stack.empty()) {
+      Type current_type = types_stack.back();
+      types_stack.pop_back();
+
+      if (const auto* tt_node = current_type.as<TupleTypeNode>()) {
+        for (const Type& type : tt_node->fields) {
+          types_stack.push_back(type);
+        }
+        continue;
+      } else if (const auto* ft_node = current_type.as<FuncTypeNode>()) {
+        types_stack.push_back(ft_node->ret_type);
+        continue;
+      }
+
+      const auto* tt_node = current_type.as<TensorTypeNode>();
+      ICHECK(tt_node) << "Expected TensorTypeNode but was " << current_type->GetTypeKey();
+      int total_tensor_bytes = GetTensorBytes(tt_node);
+      memory_usage += total_tensor_bytes;
+    }
+    return memory_usage;
+  }
+
+  /*!
+   * \brief Get the number of bytes a tensor requires.
+   *
+   * \param tensor_type_node The checked type of the tensor.
+   * \return int The number of bytes required.
+   */
+  int GetTensorBytes(const TensorTypeNode* tensor_type_node) {

Review Comment:
   Thanks, I wasn't aware of these!
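For reference, the byte count that `GetTensorBytes` computes is the element count (the product of the static shape) times the element width. The later revision of this file, quoted further down, calls the existing `CalculateRelayExprSizeBytes` helper from `utils.h` instead. The sketch below is a standalone illustration only: the hypothetical `TensorTypeSketch` stands in for `TensorTypeNode`, the sum accumulates in `uint64_t` (the type the later revision also uses), and dynamic shapes are reported rather than aborting.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for TensorTypeNode: a shape (-1 marks a dynamic dim)
// and the element width in bits (8 for int8, 32 for float32, ...).
struct TensorTypeSketch {
  std::vector<int64_t> shape;
  int dtype_bits;
};

// Bytes for a statically shaped tensor, or std::nullopt if any dim is dynamic.
std::optional<uint64_t> TensorSizeBytes(const TensorTypeSketch& t) {
  uint64_t elements = 1;
  for (int64_t dim : t.shape) {
    if (dim < 0) {
      return std::nullopt;  // size is unknown at compile time
    }
    elements *= static_cast<uint64_t>(dim);
  }
  // Round sub-byte element widths up to a whole byte (an assumption for this sketch).
  uint64_t bytes_per_element = (static_cast<uint64_t>(t.dtype_bits) + 7) / 8;
  return elements * bytes_per_element;
}
```

A `Tensor[(1, 2, 2, 4), int8]` comes out at 16 bytes, while a 1024x1024x1024 float32 tensor (4 GiB) already exceeds what the `int` return type in the quoted version can hold.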



##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/annotate_used_memory.cc
+ * \brief Analyzes the memory pressure at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/transform.h>
+
+#include "../../transforms/device_aware_visitors.h"
+#include "../manifest_lifetimes.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analysing the liveness
+ * of the input/output tensors at the function callsite and calculating the total amount of
+ * memory these tensors require.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Get the memory required for a primitive Relay function by calculating the total
+   * bytes of the live tensors at the callsite of the function.
+   *
+   * \param live_tensors The tensors that are live when the function is called.
+   * \return int The total number of bytes a function requires.
+   */
+  int GetMemoryUsage(const transform::VarSet& live_tensors) {
+    Array<Type> types_stack = {};
+    int memory_usage = 0;
+
+    for (const Var& var : live_tensors) {
+      Type var_type = var->checked_type();
+      ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass.";
+      types_stack.push_back(var_type);
+    }
+
+    while (!types_stack.empty()) {
+      Type current_type = types_stack.back();
+      types_stack.pop_back();
+
+      if (const auto* tt_node = current_type.as<TupleTypeNode>()) {
+        for (const Type& type : tt_node->fields) {
+          types_stack.push_back(type);
+        }
+        continue;
+      } else if (const auto* ft_node = current_type.as<FuncTypeNode>()) {
+        types_stack.push_back(ft_node->ret_type);

Review Comment:
   Yeah, this is because we currently visit the function when we should really be visiting the call node, which would make this unnecessary.
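Once the function-typed var is dropped from the live set at the callsite, the type walk only has to flatten tuples and sum tensor sizes. Here is a minimal sketch of that reduced walk. It keeps the explicit-stack style of `GetMemoryUsage` and uses a hypothetical `TypeSketch`, which is either a tensor with a precomputed byte size or a tuple of nested types. There is deliberately no function-type case.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical simplified type: a tensor with a precomputed byte size, or a tuple.
struct TypeSketch {
  enum class Kind { kTensor, kTuple };
  Kind kind = Kind::kTensor;
  uint64_t tensor_bytes = 0;       // meaningful when kind == kTensor
  std::vector<TypeSketch> fields;  // meaningful when kind == kTuple
};

TypeSketch MakeTensor(uint64_t bytes) {
  TypeSketch t;
  t.kind = TypeSketch::Kind::kTensor;
  t.tensor_bytes = bytes;
  return t;
}

TypeSketch MakeTuple(std::vector<TypeSketch> fields) {
  TypeSketch t;
  t.kind = TypeSketch::Kind::kTuple;
  t.fields = std::move(fields);
  return t;
}

// Sums tensor sizes over the live types, flattening nested tuples with an explicit stack.
uint64_t TotalBytes(const std::vector<TypeSketch>& live_types) {
  std::vector<const TypeSketch*> stack;
  for (const TypeSketch& t : live_types) {
    stack.push_back(&t);
  }
  uint64_t total = 0;
  while (!stack.empty()) {
    const TypeSketch* current = stack.back();
    stack.pop_back();
    if (current->kind == TypeSketch::Kind::kTuple) {
      for (const TypeSketch& field : current->fields) {
        stack.push_back(&field);
      }
    } else {
      total += current->tensor_bytes;
    }
  }
  return total;
}
```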



##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/annotate_used_memory.cc
+ * \brief Analyzes the memory pressure at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/transform.h>
+
+#include "../../transforms/device_aware_visitors.h"
+#include "../manifest_lifetimes.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analysing the liveness
+ * of the input/output tensors at the function callsite and calculating the total amount of
+ * memory these tensors require.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Get the memory required for a primitive Relay function by calculating the total
+   * bytes of the live tensors at the callsite of the function.
+   *
+   * \param live_tensors The tensors that are live when the function is called.
+   * \return int The total number of bytes a function requires.
+   */
+  int GetMemoryUsage(const transform::VarSet& live_tensors) {
+    Array<Type> types_stack = {};
+    int memory_usage = 0;

Review Comment:
   Great point, thanks!



##########
src/relay/backend/manifest_lifetimes.cc:
##########
@@ -0,0 +1,367 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/manifest_lifetimes.cc
+ * \brief Analysis and explicit manifestation of variable lifetimes. NOTE: the input IR should be in
+ * ANF and post-memory-lowering (explicit manifestation of allocations).
+ */
+
+#include "manifest_lifetimes.h"
+
+#include <list>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+using support::Arena;
+using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+
+ControlFlowGraph ControlFlowGraph::Create(Arena* arena, const Expr& body) {
+  return Creator().Create(arena, body);
+}
+
+ControlFlowGraph ControlFlowGraph::Creator::Create(Arena* arena, const Expr& body) {
+  arena_ = arena;
+  cfg_.entry = BasicBlock::Make(arena);
+  VisitExpr(body, cfg_.entry);
+  return std::move(cfg_);
+}
+
+void ControlFlowGraph::Creator::Succ(BasicBlockPtr from, BasicBlockPtr to) {
+  from->succ.push_back(to);
+  to->pred.push_back(from);
+}
+
+void ControlFlowGraph::Creator::VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) {
+  ICHECK(!in_func_) << "nested functions not supported by CFG analysis";
+  in_func_ = true;
+
+  // Unwrap the nested function and proceed normally.
+  if (f->HasNonzeroAttr(attr::kClosure)) {
+    ICHECK(f->body.as<FunctionNode>());
+    return VisitExpr(Downcast<Function>(f->body)->body, parent);
+  }
+
+  return VisitExpr(f->body, parent);
+}
+
+void ControlFlowGraph::Creator::VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) {
+  Expr expr = GetRef<Expr>(let_node);
+
+  while (const LetNode* inner_let_node = expr.as<LetNode>()) {
+    NodePtr curr_node = Node::Make(arena_, parent, expr);
+
+    ICHECK(!cfg_.let_map.count(expr));
+    cfg_.let_map[expr] = curr_node;
+    cfg_.reverse_post_order.push_back(curr_node);
+
+    // The basic block ends upon reaching control flow, with successor blocks corresponding to the
+    // control flow branch exprs (true/false in If, and one for each clause in Match).
+    if (const IfNode* ite = AsIgnoringOnDevice<IfNode>(inner_let_node->value)) {
+      // Create the basic blocks for each branch and mark them as successors to the current block.
+      BasicBlockPtr t_block = BasicBlock::Make(arena_);
+      BasicBlockPtr f_block = BasicBlock::Make(arena_);
+      Succ(parent, t_block);
+      Succ(parent, f_block);
+
+      VisitExpr(ite->true_branch, t_block);
+      VisitExpr(ite->false_branch, f_block);
+
+      // All subsequent bindings (and/or the body expr) will be in a new basic block.
+      BasicBlockPtr next = BasicBlock::Make(arena_);
+      Succ(t_block, next);
+      Succ(f_block, next);
+      parent = next;
+    } else if (const MatchNode* match = AsIgnoringOnDevice<MatchNode>(inner_let_node->value)) {
+      // Same as above but one for each pattern.
+      std::vector<BasicBlockPtr> clause_blocks;
+      BasicBlockPtr next = BasicBlock::Make(arena_);
+      for (const Clause& clause : match->clauses) {
+        BasicBlockPtr clause_block = BasicBlock::Make(arena_);
+        Succ(parent, clause_block);
+        Succ(clause_block, next);
+        VisitExpr(clause->rhs, clause_block);
+      }
+      parent = next;
+    }
+
+    expr = inner_let_node->body;
+  }
+
+  VisitExpr(expr, parent);
+}
+
+void ControlFlowGraph::Creator::VisitExpr_(const IfNode* if_node, BasicBlockPtr parent) {
+  // TODO(@altanh): is there a way of making this work?
+  LOG(FATAL) << "If expressions should be bound to variables.";
+}
+
+void ControlFlowGraph::Creator::VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent) {
+  // TODO(@altanh): same as If
+  LOG(FATAL) << "Match expressions should be bound to variables.";
+}
+
+VarSet VarUseCollector::VisitExpr_(const VarNode* var_node) { return {GetRef<Var>(var_node)}; }
+
+VarSet VarUseCollector::VisitExpr_(const CallNode* call_node) {
+  VarSet use = VisitExpr(call_node->op);
+  for (const Expr& arg : call_node->args) {
+    VarSet arg_use = VisitExpr(arg);
+    use.insert(arg_use.begin(), arg_use.end());
+  }
+  return use;
+}
+
+VarSet VarUseCollector::VisitExpr_(const TupleNode* tuple_node) {
+  VarSet use;
+  for (const Expr& field : tuple_node->fields) {
+    VarSet field_use = VisitExpr(field);
+    use.insert(field_use.begin(), field_use.end());
+  }
+  return use;
+}
+
+VarSet VarUseCollector::VisitExpr_(const TupleGetItemNode* get_node) {
+  return VisitExpr(get_node->tuple);
+}
+
+VarSet VarUseCollector::VisitExpr_(const IfNode* if_node) { return VisitExpr(if_node->cond); }
+
+VarSet VarUseCollector::VisitExpr_(const MatchNode* match_node) {
+  return VisitExpr(match_node->data);
+}
+
+UseDefAnalysis UseDefAnalysis::Analyze(const CFG& cfg) {
+  UseDefAnalysis a;
+
+  // One pass is sufficient.
+  for (auto it = cfg.reverse_post_order.begin(); it != cfg.reverse_post_order.end(); ++it) {
+    const CFG::NodePtr& node = *it;
+    if (const LetNode* let_node = AsIgnoringOnDevice<LetNode>(node->expr)) {
+      a.use[node] = a.use_collector.VisitExpr(let_node->value);
+      a.def[node] = let_node->var;
+    } else {
+      a.use[node] = a.use_collector.VisitExpr(node->expr);
+      a.def[node] = Var();
+    }
+  }
+
+  return a;
+}
+
+bool SetEqual(const VarSet& a, const VarSet& b) {
+  if (a.size() != b.size()) {
+    return false;
+  }
+  for (auto& xa : a) {
+    if (!b.count(xa)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+LivenessAnalysis LivenessAnalysis::Analyze(const ControlFlowGraph& cfg,
+                                           const UseDefAnalysis& use_def) {
+  LivenessAnalysis a;
+  std::list<CFG::NodePtr> worklist;
+
+  // Initialize worklist to post-order traversal for quick convergence.
+  worklist.insert(worklist.end(), cfg.reverse_post_order.rbegin(), cfg.reverse_post_order.rend());
+
+  // See https://lambda.uta.edu/cse5317/notes/node40.html for an overview of the algorithm.
+  auto visitor = [&](const CFG::NodePtr n) {
+    VarSet old_in_n = a.live_in[n];
+    VarSet old_out_n = a.live_out[n];
+
+    a.live_in[n] = use_def.use.at(n);
+    for (const Var& v : a.live_out[n]) {
+      if (!v.same_as(use_def.def.at(n))) {
+        a.live_in[n].insert(v);
+      }
+    }
+
+    a.live_out[n] = VarSet();
+    for (const CFG::NodePtr& s : n->GetSucc()) {
+      a.live_out[n].insert(a.live_in[s].begin(), a.live_in[s].end());
+    }
+
+    if (SetEqual(old_in_n, a.live_in[n]) && SetEqual(old_out_n, a.live_out[n])) {
+      // No need to update the worklist.
+    } else {
+      // Add predecessor nodes back to worklist (no need to add successors, since each node's
+      // in/out sets are not dependent on its predecessors).
+      for (const CFG::NodePtr& p : n->GetPred()) {
+        worklist.push_back(p);
+      }
+    }
+  };
+
+  while (!worklist.empty()) {
+    const CFG::NodePtr n = worklist.front();
+    worklist.pop_front();
+    visitor(n);
+  }
+
+  return a;
+}
+
+Expr KillInserter::VisitExpr_(const LetNode* let_node) {
+  Expr expr = GetRef<Expr>(let_node);
+  LetList ll;
+
+  while (const LetNode* inner_let_node = expr.as<LetNode>()) {
+    ll.Push(inner_let_node->var, VisitExpr(inner_let_node->value));
+
+    ICHECK(!inner_let_node->value.as<VarNode>()) << "aliasing should have been eliminated.";
+    ICHECK(cfg_->let_map.count(expr)) << "all Let exprs should be mapped in the CFG";
+
+    const ControlFlowGraph::NodePtr n = cfg_->let_map.at(expr);
+
+    const VarSet& li = lva_->live_in.at(n);
+    const VarSet& lo = lva_->live_out.at(n);
+
+    // Killed vars = live in - live out.
+    VarSet kills;
+    for (const Var& v : li) {
+      if (!lo.count(v)) {
+        kills.insert(v);
+      }
+    }
+
+    for (const Var& v : kills) {
+      ll.Push(Call(Op::Get("memory.kill"), {v}));
+    }
+
+    expr = inner_let_node->body;
+  }
+
+  return ll.Get(VisitExpr(expr));
+}
+
+Expr AliasEliminator::VisitExpr_(const LetNode* let_node) {
+  Expr expr = GetRef<Expr>(let_node);
+  LetList ll;
+  std::vector<Var> aliased_vars;
+
+  while (const LetNode* inner_let_node = expr.as<LetNode>()) {
+    const Var& var = inner_let_node->var;
+    const Expr& val = inner_let_node->value;
+    bool aliased = false;
+    ICHECK(!alias_.count(var));
+
+    if (const VarNode* alias_of_n = AsIgnoringOnDevice<VarNode>(val)) {
+      alias_[var] = Downcast<Var>(VisitExpr_(alias_of_n));
+      aliased = true;
+    } else if (AsIgnoringOnDevice<CallNode>(val)) {
+      // Copying to the same device is aliasing.
+      // WARNING: this must be kept in sync with the VM compiler logic in
+      // src/relay/backend/vm/compiler.cc, line 541, in DeviceAwareVisitExpr_(const CallNode*).

Review Comment:
   Sure, I'll try to move some of the VM-specific stuff back.
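The liveness analysis quoted above (`LivenessAnalysis::Analyze`) solves a standard backward dataflow problem, and `KillInserter` derives kill points from its result as "live in minus live out". The following standalone sketch covers both under simplifying assumptions. Vars are strings, the CFG is a vector of hypothetical `CfgNode`s indexed by position, and every node is seeded into the worklist in reverse order. This is not TVM's `ControlFlowGraph`; it only mirrors the equations.

```cpp
#include <cstddef>
#include <deque>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical CFG node: one let binding (or the final body expr).
struct CfgNode {
  std::string def;             // var defined here; empty if none
  std::set<std::string> use;   // vars read here
  std::vector<size_t> succ;    // successor node indices
  std::vector<size_t> pred;    // predecessor node indices
};

struct Liveness {
  std::vector<std::set<std::string>> live_in, live_out;
};

// Iterates to a fixed point:
//   out[n] = union of in[s] over successors s
//   in[n]  = use[n] united with (out[n] minus def[n])
Liveness AnalyzeLiveness(const std::vector<CfgNode>& nodes) {
  Liveness lva{std::vector<std::set<std::string>>(nodes.size()),
               std::vector<std::set<std::string>>(nodes.size())};
  std::deque<size_t> worklist;
  for (size_t i = nodes.size(); i-- > 0;) {
    worklist.push_back(i);  // reverse order converges quickly for a forward-built CFG
  }
  while (!worklist.empty()) {
    size_t n = worklist.front();
    worklist.pop_front();
    std::set<std::string> out;
    for (size_t s : nodes[n].succ) {
      out.insert(lva.live_in[s].begin(), lva.live_in[s].end());
    }
    std::set<std::string> in = nodes[n].use;
    for (const std::string& v : out) {
      if (v != nodes[n].def) {
        in.insert(v);
      }
    }
    if (in != lva.live_in[n] || out != lva.live_out[n]) {
      lva.live_in[n] = std::move(in);
      lva.live_out[n] = std::move(out);
      for (size_t p : nodes[n].pred) {
        worklist.push_back(p);  // predecessors depend on this node's live-in set
      }
    }
  }
  return lva;
}

// Vars whose last use is at node n: live in but not live out.
std::set<std::string> Kills(const Liveness& lva, size_t n) {
  std::set<std::string> kills;
  for (const std::string& v : lva.live_in[n]) {
    if (lva.live_out[n].count(v) == 0) {
      kills.insert(v);
    }
  }
  return kills;
}
```

For the straight-line program `%a = f(%input); %b = g(%a); %c = h(%input, %b); %c` (nodes 0 to 3), this yields `Kills(lva, 1) == {"a"}` and `Kills(lva, 2) == {"b", "input"}`. In other words, each var is killed right after its last use.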





[GitHub] [tvm] lhutton1 commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r885075902


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness
+ * of the input/output tensors at each function callsite and calculating the total amount of
+ * memory these tensors require. This is added as a "used_memory" annotation to the function
+ * in question. In addition, the containing function is annotated with an "io_used_memory"
+ * annotation which refers to the total memory required for the IO tensors.
+ *
+ * A simple example:
+ *
+ * Before:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1 = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * After:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32)
+ *       -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * Note that in the simple example above io_used_memory and used_memory are the same since there
+ * is only one primitive function.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
+   * added to the input function which refers to the total size required for the IO
+   * tensors.
+   */
+  Function operator()(const Function& func) {
+    uint64_t io_used_memory = 0;
+
+    // Inputs
+    for (const Var& param : func->params) {
+      Type type = param->checked_type();
+      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+      io_used_memory += CalculateRelayExprSizeBytes(type);
+    }
+
+    // Outputs
+    Type type = func->body->checked_type();
+    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+    io_used_memory += CalculateRelayExprSizeBytes(type);
+
+    Expr new_func_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, func->params, new_func_body);
+    return WithAttr(std::move(new_func), "io_used_memory",
+                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
+  }
+
+  /*!
+   * \brief Establish which let bindings have primitive function values.
+   */
+  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) {
+    if (const auto* func_node = value.as<FunctionNode>()) {
+      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
+          << "Expect top-level functions to be primitive.";
+      let_bound_prim_func_.insert(var);
+    }
+    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
+  }
+
+  /*!
+   * \brief Visit let nodes and perform one of two actions depending on their value:
+   *
+   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
+   *               primitive functions.
+   *
+   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
+   *                   previous analysis at the callsite.
+   *
+   */
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    Var let_var = post_let_node->var;
+    Expr let_value = IgnoreOnDevice(post_let_node->value);
+
+    if (let_value->IsInstance<CallNode>()) {
+      Call callsite = Downcast<Call>(let_value);
+      if (CheckPrimitiveFunctionCall(callsite)) {
+        Var call_op = Downcast<Var>(callsite->op);
+
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // in and out var sets and then removing the var that references the primitive
+        // function itself since we don't want this included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);
+        const transform::VarSet& live_out = liveness_.live_out.at(cfg_node);
+        live_tensors.insert(live_out.begin(), live_out.end());
+        live_tensors.erase(call_op);
+
+        // Calculate size of live tensors and store to allow annotation when the function
+        // gets visited.
+        uint64_t used_memory = 0;
+        for (const auto& var : live_tensors) {
+          Type type = var->checked_type();
+          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+          used_memory += CalculateRelayExprSizeBytes(type);

Review Comment:
   Added
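The callsite computation quoted above reduces to three steps. It merges the live-in and live-out sets of the callsite's CFG node, drops the var bound to the primitive function, and sums the sizes of what remains. Here is a minimal sketch of those steps, assuming hypothetical string var names and a precomputed size table in place of the pass's actual data structures. It is checked against the docstring example, where `used_memory=32`.

```cpp
#include <cstdint>
#include <map>
#include <set>
#include <string>

// Hypothetical inputs: the live sets at a callsite's CFG node and each var's size in bytes.
uint64_t UsedMemoryAtCallsite(const std::set<std::string>& live_in,
                              const std::set<std::string>& live_out,
                              const std::string& callee_var,
                              const std::map<std::string, uint64_t>& var_bytes) {
  // Inputs still live at the call, plus the call result and anything live across it.
  std::set<std::string> live = live_in;
  live.insert(live_out.begin(), live_out.end());
  // The var bound to the primitive function is not a tensor, so exclude it.
  live.erase(callee_var);
  uint64_t used = 0;
  for (const std::string& v : live) {
    used += var_bytes.at(v);
  }
  return used;
}
```

For `let %x_1 = %x_0(%input)`, the live-in set is `{%x_0, %input}` and the live-out set is `{%x_1}`, since the body uses the call result. Erasing `%x_0` leaves 16 + 16 = 32 bytes.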





[GitHub] [tvm] altanh commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
altanh commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r883073357


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness

Review Comment:
   Would it be more correct to say "minimum required memory" or something along those lines? I'm imagining a scenario where the primitive function internally allocates some buffers (so the actual memory usage is higher than "live in + live out").
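One way to state the annotated quantity and the gap this question points at uses the pass's own sets. Here $n$ is the CFG node of the callsite and $v_f$ is the var bound to the primitive function $f$. $W_f$ is hypothetical notation for buffers that $f$ allocates internally, and the inequality assumes those buffers do not alias the live tensors:

```latex
\mathrm{used\_memory}(f) = \sum_{v \in \left(\mathrm{live\_in}(n) \cup \mathrm{live\_out}(n)\right) \setminus \{v_f\}} \mathrm{bytes}(v),
\qquad
\mathrm{peak}(n) \;\ge\; \mathrm{used\_memory}(f) + W_f \;\ge\; \mathrm{used\_memory}(f)
```

In the docstring example, the live tensors are `%input` and `%x_1`, each 16 bytes. This gives `used_memory = 32`, which is a lower bound on the memory in use at that callsite.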



##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness
+ * of the input/output tensors at each function callsite and calculating the total amount of
+ * memory these tensors require. This is added as a "used_memory" annotation to the function
+ * in question. In addition, the containing function is annotated with an "io_used_memory"
+ * annotation which refers to the total memory required for the IO tensors.
+ *
+ * A simple example:
+ *
+ * Before:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1 = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * After:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32) -> Tensor[(1, 2, 2,
+ * 4), int8] { nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * Note that in the simple example above io_used_memory and used_memory are the same since there
+ * is only one primitive function.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
+   * added to the input function which refers to the total size required for the IO
+   * tensors.
+   */
+  Function operator()(const Function& func) {
+    uint64_t io_used_memory = 0;
+
+    // Inputs
+    for (const Var& param : func->params) {
+      Type type = param->checked_type();
+      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+      io_used_memory += CalculateRelayExprSizeBytes(type);
+    }
+
+    // Outputs
+    Type type = func->body->checked_type();
+    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+    io_used_memory += CalculateRelayExprSizeBytes(type);
+
+    Expr new_func_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, func->params, new_func_body);
+    return WithAttr(std::move(new_func), "io_used_memory",
+                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
+  }
+
+  /*!
+   * \brief Establish which let bindings have primitive function values.
+   */
+  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) {
+    if (const auto* func_node = value.as<FunctionNode>()) {
+      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
+          << "Expect top-level functions to be primitive.";
+      let_bound_prim_func_.insert(var);
+    }
+    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
+  }
+
+  /*!
+   * \brief Visit let nodes and perform one of two actions depending on their value:
+   *
+   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
+   *               primitive functions.
+   *
+   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
+   *                   previous analysis at the callsite.
+   *
+   */
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    Var let_var = post_let_node->var;
+    Expr let_value = IgnoreOnDevice(post_let_node->value);
+
+    if (let_value->IsInstance<CallNode>()) {
+      Call callsite = Downcast<Call>(let_value);
+      if (CheckPrimitiveFunctionCall(callsite)) {
+        Var call_op = Downcast<Var>(callsite->op);
+
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // in and out var sets and then removing the var that references the primitive
+        // function itself since we don't want this included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);
+        const transform::VarSet& live_out = liveness_.live_out.at(cfg_node);
+        live_tensors.insert(live_out.begin(), live_out.end());
+        live_tensors.erase(call_op);
+
+        // Calculate size of live tensors and store to allow annotation when the function
+        // gets visited.
+        uint64_t used_memory = 0;
+        for (const auto& var : live_tensors) {
+          Type type = var->checked_type();
+          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+          used_memory += CalculateRelayExprSizeBytes(type);

Review Comment:
   do we have a way of gracefully aborting this analysis if any shapes are dynamic? maybe a guard higher up where this pass gets applied?
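A sketch of one possible guard, assuming a hypothetical encoding in which a dynamic extent is stored as `-1` (Relay itself represents such a dimension as `Any`). The size computation returns `std::optional`, and the caller skips the annotation instead of failing when any live tensor is dynamic. `TensorInfo`, `StaticSizeBytes` and `UsedMemoryOrAbort` are illustrative names only.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical convention for this sketch: -1 marks a dynamic (Relay `Any`) extent.
constexpr int64_t kDynamicDim = -1;

struct TensorInfo {
  std::vector<int64_t> shape;
  int bytes_per_elem;
};

// Static size in bytes, or std::nullopt if any extent is unknown at compile time.
std::optional<uint64_t> StaticSizeBytes(const TensorInfo& t) {
  uint64_t elems = 1;
  for (int64_t d : t.shape) {
    if (d == kDynamicDim) return std::nullopt;
    elems *= static_cast<uint64_t>(d);
  }
  return elems * static_cast<uint64_t>(t.bytes_per_elem);
}

// Sums the live tensors, bailing out (meaning "do not annotate") on the first dynamic one.
std::optional<uint64_t> UsedMemoryOrAbort(const std::vector<TensorInfo>& live) {
  uint64_t total = 0;
  for (const auto& t : live) {
    std::optional<uint64_t> size = StaticSizeBytes(t);
    if (!size) return std::nullopt;
    total += *size;
  }
  return total;
}
```

Returning `std::nullopt` leaves the decision with the caller, so the same result could back either a per-callsite skip or a single check made before the pass is applied.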



##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness
+ * of the input/output tensors at each function callsite and calculating the total amount of
+ * memory these tensors require. This is added as a "used_memory" annotation to the function
+ * in question. In addition, the containing function is annotated with an "io_used_memory"
+ * annotation which refers to the total memory required for the IO tensors.
+ *
+ * A simple example:
+ *
+ * Before:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1 = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * After:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32) -> Tensor[(1, 2, 2,
+ * 4), int8] { nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * Note that in the simple example above io_used_memory and used_memory are the same since there
+ * is only one primitive function.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
+   * added to the input function which refers to the total size required for the IO
+   * tensors.
+   */
+  Function operator()(const Function& func) {
+    uint64_t io_used_memory = 0;
+
+    // Inputs
+    for (const Var& param : func->params) {
+      Type type = param->checked_type();
+      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+      io_used_memory += CalculateRelayExprSizeBytes(type);
+    }
+
+    // Outputs
+    Type type = func->body->checked_type();
+    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+    io_used_memory += CalculateRelayExprSizeBytes(type);
+
+    Expr new_func_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, func->params, new_func_body);
+    return WithAttr(std::move(new_func), "io_used_memory",
+                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
+  }
+
+  /*!
+   * \brief Establish which let bindings have primitive function values.
+   */
+  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) {
+    if (const auto* func_node = value.as<FunctionNode>()) {
+      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
+          << "Expect top-level functions to be primitive.";
+      let_bound_prim_func_.insert(var);
+    }
+    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
+  }
+
+  /*!
+   * \brief Visit let nodes and perform one of two actions depending on their value:
+   *
+   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
+   *               primitive functions.
+   *
+   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
+   *                   previous analysis at the callsite.
+   *
+   */
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    Var let_var = post_let_node->var;
+    Expr let_value = IgnoreOnDevice(post_let_node->value);
+
+    if (let_value->IsInstance<CallNode>()) {
+      Call callsite = Downcast<Call>(let_value);
+      if (CheckPrimitiveFunctionCall(callsite)) {
+        Var call_op = Downcast<Var>(callsite->op);
+
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // in and out var sets and then removing the var that references the primitive
+        // function itself since we don't want this included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);

Review Comment:
   since the live-in and live-out sets are calculated assuming a fixed/linearized execution order (induced by the input ANF), I think it's possible that the used memory calculation could end up being non-deterministic if ANF linearization is non-deterministic. I'm actually not sure if that's the case with our ANF implementation though, but just something worth keeping in mind. I'm mainly thinking in the case of e.g.
   ```
   y = f(x)
   z0 = g0(y)
   ...
   zn = gn(y)
   ```
   where the `zi`s could really be in any permutation depending on how they get linearized.
   
   could you check if the unit tests cover this?
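The concern can be reproduced in miniature. The sketch below is a hypothetical straight-line model, not TVM's `ControlFlowGraph`/`LivenessAnalysis` classes, and it leaves out the callee variables. It runs backward liveness over two valid linearizations of the `y`/`z0`/`z1` example (with an extra `out = h(z0, z1)` so both results stay live) and sums live-in ∪ live-out at each callsite. The tensor sizes are made up for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <vector>

// One ANF binding `def = callee(uses...)` in a straight-line program.
struct Binding {
  std::string def;
  std::string callee;
  std::vector<std::string> uses;
};

// Backward liveness over a linear sequence; returns the bytes of
// live-in ∪ live-out at the callsite of each callee.
std::map<std::string, uint64_t> UsedMemoryPerCallee(
    const std::vector<Binding>& program, const std::string& result,
    const std::map<std::string, uint64_t>& size) {
  std::map<std::string, uint64_t> used;
  std::set<std::string> live_out = {result};
  for (auto it = program.rbegin(); it != program.rend(); ++it) {
    // live_in = uses ∪ (live_out - def)
    std::set<std::string> live_in = live_out;
    live_in.erase(it->def);
    live_in.insert(it->uses.begin(), it->uses.end());
    std::set<std::string> live = live_in;
    live.insert(live_out.begin(), live_out.end());
    uint64_t bytes = 0;
    for (const auto& v : live) bytes += size.at(v);
    used[it->callee] = bytes;
    live_out = live_in;  // becomes live-out of the preceding binding
  }
  return used;
}
```

With sizes `x=1, y=2, z0=10, z1=100, out=1`, binding `z0` before `z1` gives 12 bytes at `g0`'s callsite, while the swapped order gives 112 because `z1` is already live when `g0` runs. A test that pins the expected annotation for such a fan-out would catch any change in linearization.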



##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness
+ * of the input/output tensors at each function callsite and calculating the total amount of
+ * memory these tensors require. This is added as a "used_memory" annotation to the function
+ * in question. In addition, the containing function is annotated with an "io_used_memory"
+ * annotation which refers to the total memory required for the IO tensors.
+ *
+ * A simple example:
+ *
+ * Before:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1 = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * After:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32) -> Tensor[(1, 2, 2,
+ * 4), int8] { nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * Note that in the simple example above io_used_memory and used_memory are the same since there
+ * is only one primitive function.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
+   * added to the input function which refers to the total size required for the IO
+   * tensors.
+   */
+  Function operator()(const Function& func) {
+    uint64_t io_used_memory = 0;
+
+    // Inputs
+    for (const Var& param : func->params) {
+      Type type = param->checked_type();
+      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+      io_used_memory += CalculateRelayExprSizeBytes(type);
+    }
+
+    // Outputs
+    Type type = func->body->checked_type();
+    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+    io_used_memory += CalculateRelayExprSizeBytes(type);
+
+    Expr new_func_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, func->params, new_func_body);
+    return WithAttr(std::move(new_func), "io_used_memory",
+                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
+  }
+
+  /*!
+   * \brief Establish which let bindings have primitive function values.
+   */
+  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) {
+    if (const auto* func_node = value.as<FunctionNode>()) {
+      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
+          << "Expect top-level functions to be primitive.";
+      let_bound_prim_func_.insert(var);
+    }
+    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
+  }
+
+  /*!
+   * \brief Visit let nodes and perform one of two actions depending on their value:
+   *
+   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
+   *               primitive functions.
+   *
+   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
+   *                   previous analysis at the callsite.
+   *
+   */
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    Var let_var = post_let_node->var;
+    Expr let_value = IgnoreOnDevice(post_let_node->value);
+
+    if (let_value->IsInstance<CallNode>()) {
+      Call callsite = Downcast<Call>(let_value);
+      if (CheckPrimitiveFunctionCall(callsite)) {
+        Var call_op = Downcast<Var>(callsite->op);
+
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // in and out var sets and then removing the var that references the primitive
+        // function itself since we don't want this included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);
+        const transform::VarSet& live_out = liveness_.live_out.at(cfg_node);
+        live_tensors.insert(live_out.begin(), live_out.end());
+        live_tensors.erase(call_op);
+
+        // Calculate size of live tensors and store to allow annotation when the function
+        // gets visited.
+        uint64_t used_memory = 0;
+        for (const auto& var : live_tensors) {
+          Type type = var->checked_type();
+          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+          used_memory += CalculateRelayExprSizeBytes(type);
+        }
+        used_memory_annotations_[call_op] = used_memory;

Review Comment:
   I think the IR form at this point prevents this, but I'd add a check that call_op has not already been annotated. If I understand correctly, this pass assumes each primitive func is used at most once.
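A minimal sketch of that check, using a hypothetical string-keyed table in place of `used_memory_annotations_`. Because `emplace` only inserts when the key is absent, a second callsite for the same primitive function is detected rather than silently overwriting the first value. In the pass itself this would presumably be an `ICHECK` rather than an exception.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for used_memory_annotations_, keyed by the name of the
// let-bound primitive function variable.
class UsedMemoryTable {
 public:
  // Records the value for a callee and refuses a second callsite, so a violated
  // "called at most once" assumption fails loudly.
  void Record(const std::string& callee, uint64_t used_memory) {
    bool inserted = annotations_.emplace(callee, used_memory).second;
    if (!inserted) {
      throw std::logic_error("primitive function " + callee + " already annotated");
    }
  }

  uint64_t Get(const std::string& callee) const { return annotations_.at(callee); }

 private:
  std::unordered_map<std::string, uint64_t> annotations_;
};
```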





[GitHub] [tvm] altanh commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
altanh commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r883864440


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness
+ * of the input/output tensors at each function callsite and calculating the total amount of
+ * memory these tensors require. This is added as a "used_memory" annotation to the function
+ * in question. In addition, the containing function is annotated with an "io_used_memory"
+ * annotation which refers to the total memory required for the IO tensors.
+ *
+ * A simple example:
+ *
+ * Before:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1 = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * After:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32) -> Tensor[(1, 2, 2,
+ * 4), int8] { nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * Note that in the simple example above io_used_memory and used_memory are the same since there
+ * is only one primitive function.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
+   * added to the input function which refers to the total size required for the IO
+   * tensors.
+   */
+  Function operator()(const Function& func) {
+    uint64_t io_used_memory = 0;
+
+    // Inputs
+    for (const Var& param : func->params) {
+      Type type = param->checked_type();
+      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+      io_used_memory += CalculateRelayExprSizeBytes(type);
+    }
+
+    // Outputs
+    Type type = func->body->checked_type();
+    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+    io_used_memory += CalculateRelayExprSizeBytes(type);
+
+    Expr new_func_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, func->params, new_func_body);
+    return WithAttr(std::move(new_func), "io_used_memory",
+                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
+  }
+
+  /*!
+   * \brief Establish which let bindings have primitive function values.
+   */
+  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) {
+    if (const auto* func_node = value.as<FunctionNode>()) {
+      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
+          << "Expect top-level functions to be primitive.";
+      let_bound_prim_func_.insert(var);
+    }
+    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
+  }
+
+  /*!
+   * \brief Visit let nodes and perform one of two actions depending on their value:
+   *
+   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
+   *               primitive functions.
+   *
+   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
+   *                   previous analysis at the callsite.
+   *
+   */
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    Var let_var = post_let_node->var;
+    Expr let_value = IgnoreOnDevice(post_let_node->value);
+
+    if (let_value->IsInstance<CallNode>()) {
+      Call callsite = Downcast<Call>(let_value);
+      if (CheckPrimitiveFunctionCall(callsite)) {
+        Var call_op = Downcast<Var>(callsite->op);
+
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // in and out var sets and then removing the var that references the primitive
+        // function itself since we don't want this included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);
+        const transform::VarSet& live_out = liveness_.live_out.at(cfg_node);
+        live_tensors.insert(live_out.begin(), live_out.end());
+        live_tensors.erase(call_op);
+
+        // Calculate size of live tensors and store to allow annotation when the function
+        // gets visited.
+        uint64_t used_memory = 0;
+        for (const auto& var : live_tensors) {
+          Type type = var->checked_type();
+          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+          used_memory += CalculateRelayExprSizeBytes(type);
+        }
+        used_memory_annotations_[call_op] = used_memory;

Review Comment:
   that feels a bit weird to me if we allow primitive funcs to be called multiple times (but this is a sound over-approximation). Would it be too much to make the annotation a list of memory usages, corresponding to different callsites? This might be sufficient for your use case though, so I don't have a strong preference other than leaning towards future-proofing.
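A sketch of the list-valued alternative, assuming a hypothetical table keyed by callee name: each callsite appends its own value, and taking the maximum recovers the single conservative number that a scalar annotation would have to carry. How such a list would be encoded as a function attribute is left open here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical per-callsite variant of the annotation: every callsite of a
// primitive function appends its own used-memory value, in visit order.
class PerCallsiteUsedMemory {
 public:
  void Record(const std::string& callee, uint64_t used_memory) {
    annotations_[callee].push_back(used_memory);
  }

  const std::vector<uint64_t>& Get(const std::string& callee) const {
    return annotations_.at(callee);
  }

  // Collapsing the list to its maximum gives the sound over-approximation
  // that a single-integer annotation would have to use.
  uint64_t Max(const std::string& callee) const {
    const std::vector<uint64_t>& values = annotations_.at(callee);
    return *std::max_element(values.begin(), values.end());
  }

 private:
  std::unordered_map<std::string, std::vector<uint64_t>> annotations_;
};
```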





[GitHub] [tvm] lhutton1 commented on pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on PR #11208:
URL: https://github.com/apache/tvm/pull/11208#issuecomment-1138783701

   Apologies for the delay, this is ready for another look!




[GitHub] [tvm] manupa-arm commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r883470240


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness

Review Comment:
   I think "usage *of*" is the misleading bit here. It should say something along the lines of the memory used at the time the call is made.





[GitHub] [tvm] manupa-arm commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r883978886


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness
+ * of the input/output tensors at each function callsite and calculating the total amount of
+ * memory these tensors require. This is added as a "used_memory" annotation to the function
+ * in question. In addition, the containing function is annotated with an "io_used_memory"
+ * annotation which refers to the total memory required for the IO tensors.
+ *
+ * A simple example:
+ *
+ * Before:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1 = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * After:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32) -> Tensor[(1, 2, 2,
+ * 4), int8] { nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * Note that in the simple example above io_used_memory and used_memory are the same since there
+ * is only one primitive function.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
+   * added to the input function which refers to the total size required for the IO
+   * tensors.
+   */
+  Function operator()(const Function& func) {
+    uint64_t io_used_memory = 0;
+
+    // Inputs
+    for (const Var& param : func->params) {
+      Type type = param->checked_type();
+      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+      io_used_memory += CalculateRelayExprSizeBytes(type);
+    }
+
+    // Outputs
+    Type type = func->body->checked_type();
+    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+    io_used_memory += CalculateRelayExprSizeBytes(type);
+
+    Expr new_func_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, func->params, new_func_body);
+    return WithAttr(std::move(new_func), "io_used_memory",
+                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
+  }
+
+  /*!
+   * \brief Establish which let bindings have primitive function values.
+   */
+  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) {
+    if (const auto* func_node = value.as<FunctionNode>()) {
+      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
+          << "Expect top-level functions to be primitive.";
+      let_bound_prim_func_.insert(var);
+    }
+    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
+  }
+
+  /*!
+   * \brief Visit let nodes and perform one of two actions depending on their value:
+   *
+   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
+   *               primitive functions.
+   *
+   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
+   *                   previous analysis at the callsite.
+   *
+   */
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    Var let_var = post_let_node->var;
+    Expr let_value = IgnoreOnDevice(post_let_node->value);
+
+    if (let_value->IsInstance<CallNode>()) {
+      Call callsite = Downcast<Call>(let_value);
+      if (CheckPrimitiveFunctionCall(callsite)) {
+        Var call_op = Downcast<Var>(callsite->op);
+
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // in and out varset's and then removing the var that references the primitive
+        // function itself since we don't want this included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);
+        const transform::VarSet& live_out = liveness_.live_out.at(cfg_node);
+        live_tensors.insert(live_out.begin(), live_out.end());
+        live_tensors.erase(call_op);
+
+        // Calculate size of live tensors and store to allow annotation when the function
+        // gets visited.
+        uint64_t used_memory = 0;
+        for (const auto& var : live_tensors) {
+          Type type = var->checked_type();
+          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+          used_memory += CalculateRelayExprSizeBytes(type);
+        }
+        used_memory_annotations_[call_op] = used_memory;

Review Comment:
   A list seems fine as well; we can consume it by taking the max for our use case.
   
   @lhutton1 WDYT?
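   Below is a minimal sketch of the calculation in the quoted hunk and of the "take the max" consumption described above. It assumes, hypothetically, that a primitive function called from several callsites would get one `used_memory` value per callsite. `Callsite`, `UsedMemoryAt` and `PeakUsedMemory` are illustrative names, not TVM APIs, and variables are plain strings rather than `relay::Var`.

```cpp
#include <algorithm>
#include <cstdint>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical, simplified stand-in for the liveness info at one callsite.
struct Callsite {
  std::string callee;              // var bound to the primitive function
  std::set<std::string> live_in;   // vars live on entry to the let binding
  std::set<std::string> live_out;  // vars live on exit from the let binding
};

// Bytes used at one callsite: sum over (live_in U live_out) \ {callee},
// mirroring the merge-then-erase in the quoted PostVisitLet_ code.
uint64_t UsedMemoryAt(const Callsite& cs,
                      const std::unordered_map<std::string, uint64_t>& size_bytes) {
  std::set<std::string> live = cs.live_in;
  live.insert(cs.live_out.begin(), cs.live_out.end());
  live.erase(cs.callee);  // the function var itself is not a tensor
  uint64_t total = 0;
  for (const auto& var : live) {
    total += size_bytes.at(var);  // throws if a live var has no recorded size
  }
  return total;
}

// If "used_memory" holds one value per callsite, a consumer can take the max.
uint64_t PeakUsedMemory(const std::vector<uint64_t>& used_memory) {
  if (used_memory.empty()) return 0;
  return *std::max_element(used_memory.begin(), used_memory.end());
}
```

   For the doc-comment example, `%input` and `%x_1` are each `Tensor[(1, 2, 2, 4), int8]`, i.e. 16 bytes, so the single callsite gives 32 bytes, matching `used_memory=32`.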





[GitHub] [tvm] manupa-arm commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r877258645


##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -1063,6 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     }
 
     mod = transform::ToANormalForm()(mod);
+    mod = transform::InferType()(mod);
+    mod = transform::AnnotateUsedMemory()(mod);

Review Comment:
   This is mainly to get an idea of how aggressive the scheduling needs to be in terms of memory when we lower them to TIR.
   
   Yes, USMP will analyze and optimize further post-scheduling, attempting to hit "a" theoretical minimum (i.e. the memory pressure).





[GitHub] [tvm] lhutton1 commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r878085445


##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/annotate_used_memory.cc
+ * \brief Analyzes the memory pressure at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/transform.h>
+
+#include "../../transforms/device_aware_visitors.h"
+#include "../manifest_lifetimes.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analysing the liveness
+ * of the input/output tensors at the function callsite and calculating the total amount of
+ * memory these tensors require.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Get the memory required for a primitive Relay function by calculating the total
+   * bytes of the live tensors at the callsite of the function.
+   *
+   * \param live_tensors The tensors that are live when the function is called.
+   * \return int The total number of bytes a function requires.
+   */
+  int GetMemoryUsage(const transform::VarSet& live_tensors) {
+    Array<Type> types_stack = {};
+    int memory_usage = 0;
+
+    for (const Var& var : live_tensors) {
+      Type var_type = var->checked_type();
+      ICHECK(var_type.defined()) << "InferType pass should be run before AnnotateUsedMemory pass.";
+      types_stack.push_back(var_type);
+    }
+
+    while (!types_stack.empty()) {
+      Type current_type = types_stack.back();
+      types_stack.pop_back();
+
+      if (const auto* tt_node = current_type.as<TupleTypeNode>()) {
+        for (const Type& type : tt_node->fields) {
+          types_stack.push_back(type);
+        }
+        continue;
+      } else if (const auto* ft_node = current_type.as<FuncTypeNode>()) {
+        types_stack.push_back(ft_node->ret_type);

Review Comment:
   Yeah, this is because we currently visit the function when we should really be visiting the call node; once that's changed, this branch won't be required.
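   As a rough illustration of the point above: once sizes are taken from the vars live at the call node, a live var is (assuming no higher-order values) a tensor or a possibly nested tuple of tensors. The stack-based walk then only needs to flatten tuples, and the `FuncTypeNode` branch goes away. `SimpleType` and `TotalBytes` are a hypothetical simplified model, not TVM's `relay::Type`.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical, simplified type model (not TVM's relay::Type).
struct SimpleType {
  enum class Kind { kTensor, kTuple } kind;
  uint64_t tensor_bytes = 0;       // meaningful when kind == kTensor
  std::vector<SimpleType> fields;  // meaningful when kind == kTuple
};

// Sum the bytes of all tensors, flattening nested tuples with an explicit
// stack (same shape as the quoted GetMemoryUsage loop, minus FuncType).
uint64_t TotalBytes(const std::vector<SimpleType>& live_types) {
  std::vector<const SimpleType*> stack;
  for (const auto& t : live_types) stack.push_back(&t);
  uint64_t total = 0;
  while (!stack.empty()) {
    const SimpleType* current = stack.back();
    stack.pop_back();
    if (current->kind == SimpleType::Kind::kTuple) {
      for (const auto& field : current->fields) stack.push_back(&field);
    } else {
      total += current->tensor_bytes;
    }
  }
  return total;
}
```

   Using `uint64_t` for the accumulator also avoids the overflow risk of the `int memory_usage` in the quoted hunk for large models.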





[GitHub] [tvm] areusch commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
areusch commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r879752148


##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -1063,6 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     }
 
     mod = transform::ToANormalForm()(mod);
+    mod = transform::InferType()(mod);
+    mod = transform::AnnotateUsedMemory()(mod);

Review Comment:
   got it, so is this intended as cascader input?





[GitHub] [tvm] altanh commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
altanh commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r886194345


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness
+ * of the input/output tensors at each function callsite and calculating the total amount of
+ * memory these tensors require. This is added as a "used_memory" annotation to the function
+ * in question. In addition, the containing function is annotated with an "io_used_memory"
+ * annotation which refers to the total memory required for the IO tensors.
+ *
+ * A simple example:
+ *
+ * Before:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1 = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * After:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32) -> Tensor[(1, 2, 2,
+ * 4), int8] { nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * Note that in the simple example above io_used_memory and used_memory are the same since there
+ * is only one primitive function.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
+   * added to the input function which refers to the total size required for the IO
+   * tensors.
+   */
+  Function operator()(const Function& func) {
+    uint64_t io_used_memory = 0;
+
+    // Inputs
+    for (const Var& param : func->params) {
+      Type type = param->checked_type();
+      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+      io_used_memory += CalculateRelayExprSizeBytes(type);
+    }
+
+    // Outputs
+    Type type = func->body->checked_type();
+    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+    io_used_memory += CalculateRelayExprSizeBytes(type);
+
+    Expr new_func_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, func->params, new_func_body);
+    return WithAttr(std::move(new_func), "io_used_memory",
+                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
+  }
+
+  /*!
+   * \brief Establish which let bindings have primitive function values.
+   */
+  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) {
+    if (const auto* func_node = value.as<FunctionNode>()) {
+      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
+          << "Expect top-level functions to be primitive.";
+      let_bound_prim_func_.insert(var);
+    }
+    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
+  }
+
+  /*!
+   * \brief Visit let nodes and perform one of two actions depending on their value:
+   *
+   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
+   *               primitive functions.
+   *
+   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
+   *                   previous analysis at the callsite.
+   *
+   */
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    Var let_var = post_let_node->var;
+    Expr let_value = IgnoreOnDevice(post_let_node->value);
+
+    if (let_value->IsInstance<CallNode>()) {
+      Call callsite = Downcast<Call>(let_value);
+      if (CheckPrimitiveFunctionCall(callsite)) {
+        Var call_op = Downcast<Var>(callsite->op);
+
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // in and out varset's and then removing the var that references the primitive
+        // function itself since we don't want this included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);
+        const transform::VarSet& live_out = liveness_.live_out.at(cfg_node);
+        live_tensors.insert(live_out.begin(), live_out.end());
+        live_tensors.erase(call_op);
+
+        // Calculate size of live tensors and store to allow annotation when the function
+        // gets visited.
+        uint64_t used_memory = 0;
+        for (const auto& var : live_tensors) {
+          Type type = var->checked_type();
+          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+          used_memory += CalculateRelayExprSizeBytes(type);

Review Comment:
   I was hoping to avoid a hard failure like this, but as long as we can guarantee this pass doesn't get run when there are dynamic shapes in the model (which is valid in normal Relay-land), then I'm not too opposed. Can you confirm this won't be going in the normal pass flow?





[GitHub] [tvm] lhutton1 commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r886648892


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness
+ * of the input/output tensors at each function callsite and calculating the total amount of
+ * memory these tensors require. This is added as a "used_memory" annotation to the function
+ * in question. In addition, the containing function is annotated with an "io_used_memory"
+ * annotation which refers to the total memory required for the IO tensors.
+ *
+ * A simple example:
+ *
+ * Before:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1 = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * After:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32) -> Tensor[(1, 2, 2,
+ * 4), int8] { nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * Note that in the simple example above io_used_memory and used_memory are the same since there
+ * is only one primitive function.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
+   * added to the input function which refers to the total size required for the IO
+   * tensors.
+   */
+  Function operator()(const Function& func) {
+    uint64_t io_used_memory = 0;
+
+    // Inputs
+    for (const Var& param : func->params) {
+      Type type = param->checked_type();
+      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+      io_used_memory += CalculateRelayExprSizeBytes(type);
+    }
+
+    // Outputs
+    Type type = func->body->checked_type();
+    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+    io_used_memory += CalculateRelayExprSizeBytes(type);
+
+    Expr new_func_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, func->params, new_func_body);
+    return WithAttr(std::move(new_func), "io_used_memory",
+                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
+  }
+
+  /*!
+   * \brief Establish which let bindings have primitive function values.
+   */
+  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) {
+    if (const auto* func_node = value.as<FunctionNode>()) {
+      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
+          << "Expect top-level functions to be primitive.";
+      let_bound_prim_func_.insert(var);
+    }
+    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
+  }
+
+  /*!
+   * \brief Visit let nodes and perform one of two actions depending on their value:
+   *
+   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
+   *               primitive functions.
+   *
+   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
+   *                   previous analysis at the callsite.
+   *
+   */
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    Var let_var = post_let_node->var;
+    Expr let_value = IgnoreOnDevice(post_let_node->value);
+
+    if (let_value->IsInstance<CallNode>()) {
+      Call callsite = Downcast<Call>(let_value);
+      if (CheckPrimitiveFunctionCall(callsite)) {
+        Var call_op = Downcast<Var>(callsite->op);
+
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // in and out varset's and then removing the var that references the primitive
+        // function itself since we don't want this included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);
+        const transform::VarSet& live_out = liveness_.live_out.at(cfg_node);
+        live_tensors.insert(live_out.begin(), live_out.end());
+        live_tensors.erase(call_op);
+
+        // Calculate size of live tensors and store to allow annotation when the function
+        // gets visited.
+        uint64_t used_memory = 0;
+        for (const auto& var : live_tensors) {
+          Type type = var->checked_type();
+          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+          used_memory += CalculateRelayExprSizeBytes(type);

Review Comment:
   Ah, apologies, I see what you mean. Since we only intend to use this pass with the AOT executor, which as far as I understand doesn't support dynamic shapes, I think it's okay for the pass to go unchecked where it gets applied. I think adding a comment to the pass description might be helpful for anyone else using the pass, though, stating something along the lines of: `Note: This pass does not support dynamic shapes, it is the user's responsibility to check this pass isn't applied where dynamic shapes may be input`. WDYT?
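   One way a caller could take on that responsibility, sketched under stated assumptions: compute sizes as an optional value and only run the analysis when every tensor is static, instead of hitting the `ICHECK`. Relay marks dynamic dimensions with `Any`. This sketch instead uses a hypothetical convention where a dimension of `-1` means dynamic, and `TensorDesc`, `StaticSizeBytes` and `AllShapesStatic` are illustrative names, not TVM APIs.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical tensor description; a dim of -1 marks a dynamic dimension
// (an assumption for this sketch, not Relay's representation).
struct TensorDesc {
  std::vector<int64_t> shape;
  uint64_t bytes_per_element;
};

// Size in bytes, or std::nullopt if any dimension is dynamic.
std::optional<uint64_t> StaticSizeBytes(const TensorDesc& t) {
  uint64_t elements = 1;
  for (int64_t dim : t.shape) {
    if (dim < 0) return std::nullopt;
    elements *= static_cast<uint64_t>(dim);
  }
  return elements * t.bytes_per_element;
}

// Caller-side guard: only apply the analysis when every tensor is static.
bool AllShapesStatic(const std::vector<TensorDesc>& tensors) {
  for (const auto& t : tensors) {
    if (!StaticSizeBytes(t).has_value()) return false;
  }
  return true;
}
```

   For example, `Tensor[(1, 2, 2, 4), int8]` gives 16 bytes, while the same shape with a dynamic second dimension yields no size, so the guard would skip the pass.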





[GitHub] [tvm] lhutton1 commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r885076426


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness
+ * of the input/output tensors at each function callsite and calculating the total amount of
+ * memory these tensors require. This is added as a "used_memory" annotation to the function
+ * in question. In addition, the containing function is annotated with an "io_used_memory"
+ * annotation which refers to the total memory required for the IO tensors.
+ *
+ * A simple example:
+ *
+ * Before:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1 = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * After:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32) -> Tensor[(1, 2, 2,
+ * 4), int8] { nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * Note that in the simple example above io_used_memory and used_memory are the same since there
+ * is only one primitive function.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
+   * added to the input function which refers to the total size required for the IO
+   * tensors.
+   */
+  Function operator()(const Function& func) {
+    uint64_t io_used_memory = 0;
+
+    // Inputs
+    for (const Var& param : func->params) {
+      Type type = param->checked_type();
+      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+      io_used_memory += CalculateRelayExprSizeBytes(type);
+    }
+
+    // Outputs
+    Type type = func->body->checked_type();
+    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+    io_used_memory += CalculateRelayExprSizeBytes(type);
+
+    Expr new_func_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, func->params, new_func_body);
+    return WithAttr(std::move(new_func), "io_used_memory",
+                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
+  }
+
+  /*!
+   * \brief Establish which let bindings have primitive function values.
+   */
+  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) {
+    if (const auto* func_node = value.as<FunctionNode>()) {
+      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
+          << "Expect top-level functions to be primitive.";
+      let_bound_prim_func_.insert(var);
+    }
+    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
+  }
+
+  /*!
+   * \brief Visit let nodes and perform one of two actions depending on their value:
+   *
+   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
+   *               primitive functions.
+   *
+   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
+   *                   previous analysis at the callsite.
+   *
+   */
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    Var let_var = post_let_node->var;
+    Expr let_value = IgnoreOnDevice(post_let_node->value);
+
+    if (let_value->IsInstance<CallNode>()) {
+      Call callsite = Downcast<Call>(let_value);
+      if (CheckPrimitiveFunctionCall(callsite)) {
+        Var call_op = Downcast<Var>(callsite->op);
+
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // in and out varset's and then removing the var that references the primitive
+        // function itself since we don't want this included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);

Review Comment:
   Added a case here https://github.com/apache/tvm/pull/11208/files#diff-9eda68677259669b60c638c123a7a1d2437a86e1121f8d240581aef3c30492eeR356 -- I'm hoping I understood correctly?





[GitHub] [tvm] lhutton1 commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r885075767


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness

Review Comment:
   Thanks, I've updated it; hopefully it makes more sense now.





[GitHub] [tvm] lhutton1 commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r886655551


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness
+ * of the input/output tensors at each function callsite and calculating the total amount of
+ * memory these tensors require. This is added as a "used_memory" annotation to the function
+ * in question. In addition, the containing function is annotated with an "io_used_memory"
+ * annotation which refers to the total memory required for the IO tensors.
+ *
+ * A simple example:
+ *
+ * Before:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1 = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * After:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32) -> Tensor[(1, 2, 2,
+ * 4), int8] { nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * Note that in the simple example above io_used_memory and used_memory are the same since there
+ * is only one primitive function.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
+   * added to the input function which refers to the total size required for the IO
+   * tensors.
+   */
+  Function operator()(const Function& func) {
+    uint64_t io_used_memory = 0;
+
+    // Inputs
+    for (const Var& param : func->params) {
+      Type type = param->checked_type();
+      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+      io_used_memory += CalculateRelayExprSizeBytes(type);
+    }
+
+    // Outputs
+    Type type = func->body->checked_type();
+    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+    io_used_memory += CalculateRelayExprSizeBytes(type);
+
+    Expr new_func_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, func->params, new_func_body);
+    return WithAttr(std::move(new_func), "io_used_memory",
+                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
+  }
+
+  /*!
+   * \brief Establish which let bindings have primitive function values.
+   */
+  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) {
+    if (const auto* func_node = value.as<FunctionNode>()) {
+      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
+          << "Expect top-level functions to be primitive.";
+      let_bound_prim_func_.insert(var);
+    }
+    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
+  }
+
+  /*!
+   * \brief Visit let nodes and perform one of two actions depending on their value:
+   *
+   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
+   *               primitive functions.
+   *
+   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
+   *                   previous analysis at the callsite.
+   *
+   */
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    Var let_var = post_let_node->var;
+    Expr let_value = IgnoreOnDevice(post_let_node->value);
+
+    if (let_value->IsInstance<CallNode>()) {
+      Call callsite = Downcast<Call>(let_value);
+      if (CheckPrimitiveFunctionCall(callsite)) {
+        Var call_op = Downcast<Var>(callsite->op);
+
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // in and out varset's and then removing the var that references the primitive
+        // function itself since we don't want this included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);

Review Comment:
   I see, thanks. In that case, yes, the unit tests should cover this. Otherwise, I believe we would see mismatches between the hard-coded expected values and the result of running the pass, which relies on the ordering being correct. If the result of ANF were non-deterministic, we would see these tests fail intermittently, which hasn't been the case in my experience.
   
   In terms of our use case, your thinking is correct: we would simply take the max of all the used_memory annotations on a function, so ordering isn't really a problem for us.





[GitHub] [tvm] areusch commented on pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
areusch commented on PR #11208:
URL: https://github.com/apache/tvm/pull/11208#issuecomment-1134992902

   thanks @lhutton1 for the input! ping us when this is ready for review again.




[GitHub] [tvm] zhaoyang-star commented on pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
zhaoyang-star commented on PR #11208:
URL: https://github.com/apache/tvm/pull/11208#issuecomment-1280220884

   > Hi @zhaoyang-star, thanks for taking a look, it's great to see this pass being used elsewhere. The pass currently expects the input to be a module of primitive functions, so I would suggest running `AnnotateUsedMemory` after `FuseOps`, similar to:
   > 
   > ```
   > mod = relay.transform.InferType()(mod)
   > mod = relay.transform.FuseOps()(mod)
   > mod = relay.transform.InferType()(mod)
   > mod = relay.transform.ToANormalForm()(mod)
   > mod = relay.transform.InferType()(mod)
   > mod = AnnotateUsedMemory()(mod)
   > ```
   > 
   > I did try running your example locally with the above change, and this produced the relevant `used_memory` annotations. However, it looks like there is an issue when building the module after running the `AnnotateUsedMemory` pass. Without digging too much into it, I suspect it's because this pass was only considered for the AOT executor, not the graph executor. I believe changes similar to #11091 would be needed in the graph executor to support A-normal form. Hope this helps :)
   
   I want to confirm: did you reproduce the issue (no `used_memory` attr in the output log) using my script above? If it ran correctly for you, could you please share your script? With my script, I only find one `io_used_memory` attr and no `used_memory` attr.
   
   If I place FuseOps before AnnotateUsedMemory as you showed, I get the error `Check failed: (tensor_type) is false:`. You mentioned that we may need to support ANF in the graph executor to resolve this error.




[GitHub] [tvm] manupa-arm commented on pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on PR #11208:
URL: https://github.com/apache/tvm/pull/11208#issuecomment-1165726828

   @lhutton1 since it has been 18 days, should we re-run a round of CI -- just to be sure :)




[GitHub] [tvm] manupa-arm commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r875009499


##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/annotate_used_memory.cc
+ * \brief Analyzes the memory pressure at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/transform.h>
+
+#include "../../transforms/device_aware_visitors.h"
+#include "../manifest_lifetimes.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analysing the liveness
+ * of the input/output tensors at the function callsite and calculating the total amount of
+ * memory these tensors require.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Get the memory required for a primitive Relay function by calculating the total
+   * bytes of the live tensors at the callsite of the function.
+   *
+   * \param live_tensors The tensors that are live when the function is called.
+   * \return int The total number of bytes a function requires.
+   */
+  int GetMemoryUsage(const transform::VarSet& live_tensors) {
+    Array<Type> types_stack = {};
+    int memory_usage = 0;
+
+    for (const Var& var : live_tensors) {
+      Type var_type = var->checked_type();
+      ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass.";
+      types_stack.push_back(var_type);
+    }
+
+    while (!types_stack.empty()) {

Review Comment:
   Is there a reason to do a depth-first traversal of the types?
   (As opposed to expanding the types in a flat manner to get the bytes.)
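To illustrate the two options, here is a minimal Python sketch. It uses a hypothetical stand-in type model (`TensorType` and `TupleType` below are illustrative dataclasses, not TVM's classes, and assume static shapes). It contrasts the explicit-stack traversal used in the diff with a flat recursive expansion of nested tuples. Because summation order does not affect the total, both give the same byte count; the choice is mostly about readability.

```python
from dataclasses import dataclass
from math import prod
from typing import List, Tuple, Union


@dataclass(frozen=True)
class TensorType:
    """Hypothetical stand-in for a Relay TensorType (static shape only)."""
    shape: Tuple[int, ...]
    dtype_bytes: int


@dataclass(frozen=True)
class TupleType:
    """Hypothetical stand-in for a Relay TupleType."""
    fields: Tuple["Type", ...]


Type = Union[TensorType, TupleType]


def tensor_bytes(t: TensorType) -> int:
    return prod(t.shape) * t.dtype_bytes


def bytes_with_stack(types: List[Type]) -> int:
    # Mirrors the diff: pop a type; push tuple fields back onto the stack.
    stack = list(types)
    total = 0
    while stack:
        t = stack.pop()
        if isinstance(t, TupleType):
            stack.extend(t.fields)
        else:
            total += tensor_bytes(t)
    return total


def flatten(t: Type) -> List[TensorType]:
    # Expand arbitrarily nested tuples into a flat list of tensor types.
    if isinstance(t, TupleType):
        return [leaf for field in t.fields for leaf in flatten(field)]
    return [t]


def bytes_flat(types: List[Type]) -> int:
    return sum(tensor_bytes(leaf) for t in types for leaf in flatten(t))
```

For example, an int8 `(1, 2, 2, 4)` tensor plus a nested tuple holding an int16 and a uint32 `(1, 2, 2, 2)` tensor totals 16 + 16 + 32 = 64 bytes with either function.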



##########
tests/python/relay/aot/test_used_memory_annotator.py:
##########
@@ -0,0 +1,194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+
+"""
+Testing for the pass that annotates used memory for each primitive
+Relay function.
+"""
+
+import tvm
+from tvm import relay
+from tvm.relay.expr_functor import ExprVisitor
+
+
+def AnnotateUsedMemory():
+    return relay.transform._ffi_api.AnnotateUsedMemory()
+
+
+class CheckUsedMemoryAnnotation(ExprVisitor):
+    """
+    Check that the annotations on each function in the graph match
+    what is expected.
+    """
+
+    def __init__(self, expected_annotations):
+        self.expected_annotations = expected_annotations
+        super().__init__()
+
+    def visit_function(self, fn):
+        if "Primitive" in fn.attrs:
+            assert (
+                "used_memory" in fn.attrs
+            ), "Primitive function does not have used_memory annotation."
+
+            assert len(self.expected_annotations) > 0, "Not all expected annotations were compared"
+
+            expected_mem = self.expected_annotations.pop(0)
+            actual_mem = fn.attrs["used_memory"]
+            assert expected_mem == actual_mem, (
+                f"Expected used memory annotation {expected_mem} "
+                f"did not match actual annotation {actual_mem}"
+            )
+        super().visit_function(fn)
+
+
+def _check_used_memory_annotations(mod, expected_annotations):
+    mod = relay.transform.InferType()(mod)
+    mod = relay.transform.ToANormalForm()(mod)
+    mod = relay.transform.InferType()(mod)
+    mod = AnnotateUsedMemory()(mod)
+
+    CheckUsedMemoryAnnotation(expected_annotations).visit(mod["main"].body)
+
+
+def _create_primitive_function(expr):
+    func = relay.Function(relay.analysis.free_vars(expr), expr)
+    func = func.with_attr("Primitive", 1)
+    return func
+
+
+def test_simple():
+    """
+    Test simple graph with one primitive function.
+    """
+
+    def get_inner_func():
+        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
+        x = relay.nn.max_pool2d(x)
+        x = _create_primitive_function(x)
+        return x
+
+    ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8")
+    call = relay.Call(get_inner_func(), [ifm])
+    mod = tvm.IRModule.from_expr(call)
+
+    expected_annotations = [2 * (1 * 2 * 2 * 4)]
+    _check_used_memory_annotations(mod, expected_annotations)
+
+
+def test_multiple_functions():
+    """
+    Test a graph with multiple primitive functions.
+    """
+
+    def get_inner_func(ifm_shape):
+        x = relay.var("x", shape=ifm_shape, dtype="int8")
+        x = relay.nn.max_pool2d(x, pool_size=(2, 2), layout="NHWC")
+        x = _create_primitive_function(x)
+        return x
+
+    ifm = relay.var("input", shape=(1, 8, 8, 2), dtype="int8")
+    x = get_inner_func((1, 8, 8, 2))
+    x = relay.Call(x, [ifm])
+    y = get_inner_func((1, 7, 7, 2))
+    y = relay.Call(y, [x])
+    z = get_inner_func((1, 6, 6, 2))
+    z = relay.Call(z, [y])
+    mod = tvm.IRModule.from_expr(z)
+
+    expected_annotations = [
+        (1 * 8 * 8 * 2) + (1 * 7 * 7 * 2),
+        (1 * 7 * 7 * 2) + (1 * 6 * 6 * 2),
+        (1 * 6 * 6 * 2) + (1 * 5 * 5 * 2),
+    ]
+    _check_used_memory_annotations(mod, expected_annotations)
+
+
+def test_mixed_data_types():
+    """
+    Test a graph with a primitive function that has mixed datatypes.
+    """
+
+    def get_inner_func():
+        x = relay.var("x", shape=(1, 2, 2, 2), dtype="int16")
+        x = relay.cast(x, dtype="uint32")
+        x = _create_primitive_function(x)
+        return x
+
+    ifm = relay.var("input", shape=(1, 2, 2, 2), dtype="int16")
+    x = get_inner_func()
+    x = relay.Call(x, [ifm])
+    mod = tvm.IRModule.from_expr(x)
+
+    expected_annotations = [
+        (1 * 2 * 2 * 2) * 2 + (1 * 2 * 2 * 2) * 4,
+    ]
+    _check_used_memory_annotations(mod, expected_annotations)
+
+
+def test_parallel_function_call():

Review Comment:
   A few suggestions for more test cases:
   1) Nested branches
   2) Long branches (>3) where each/some branch has more than one operator.
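When writing hard-coded expected values for such tests, it can help to derive them by hand. The following is a minimal Python sketch, not the pass itself. It assumes a straight-line A-normal-form chain of primitive calls (the `Call` record, `used_memory` helper, and sizes are hypothetical). It runs backward liveness and sums the bytes of every var live into or out of each call, following the pass's comment about merging the in and out var sets. Function-bound vars are simply not modelled, which stands in for the pass removing the callee var.

```python
from dataclasses import dataclass
from typing import Dict, List, Set


@dataclass
class Call:
    """Hypothetical record for one ANF binding: `let out = prim_fn(args...)`."""
    out: str
    args: List[str]


def used_memory(calls: List[Call], result: str, sizes: Dict[str, int]) -> List[int]:
    """Bytes of the tensors live into or out of each call, in program order."""
    n = len(calls)
    live_in: List[Set[str]] = [set() for _ in range(n)]
    live_out: List[Set[str]] = [set() for _ in range(n)]
    live = {result}  # the function body returns `result`
    for i in reversed(range(n)):
        live_out[i] = set(live)
        # live_in = (live_out - defined) | used
        live = (live - {calls[i].out}) | set(calls[i].args)
        live_in[i] = set(live)
    return [sum(sizes[v] for v in live_in[i] | live_out[i]) for i in range(n)]


# Chain from test_multiple_functions: int8 tensors of 8x8x2, 7x7x2, 6x6x2, 5x5x2.
chain_sizes = {"input": 128, "x1": 98, "x3": 72, "x5": 50}
chain = [Call("x1", ["input"]), Call("x3", ["x1"]), Call("x5", ["x3"])]

# A simple two-branch graph: a = f(input); b = g(input); c = h(a, b).
branch_sizes = {"input": 10, "a": 5, "b": 6, "c": 7}
branch = [Call("a", ["input"]), Call("b", ["input"]), Call("c", ["a", "b"])]
```

The chain reproduces the `expected_annotations` of `test_multiple_functions` above (226, 170, 122). In the branch case, `a` stays live across the call to `g` even though `g` never reads it, giving 15, 21 and 18 bytes. Nested or long-branch tests would exercise exactly this behavior.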





[GitHub] [tvm] manupa-arm commented on pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on PR #11208:
URL: https://github.com/apache/tvm/pull/11208#issuecomment-1166268834

   Thanks @lhutton1 @altanh @areusch ! This is merged now




[GitHub] [tvm] areusch commented on pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
areusch commented on PR #11208:
URL: https://github.com/apache/tvm/pull/11208#issuecomment-1139014246

   ok thanks @lhutton1 ! i'll defer to @altanh on this one




[GitHub] [tvm] manupa-arm commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r880096340


##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -1063,6 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     }
 
     mod = transform::ToANormalForm()(mod);
+    mod = transform::InferType()(mod);
+    mod = transform::AnnotateUsedMemory()(mod);

Review Comment:
   Yes, the cascader is the use case we are interested in.
   
   





[GitHub] [tvm] altanh commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
altanh commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r877477353


##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/annotate_used_memory.cc
+ * \brief Analyzes the memory pressure at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/transform.h>
+
+#include "../../transforms/device_aware_visitors.h"
+#include "../manifest_lifetimes.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analysing the liveness
+ * of the input/output tensors at the function callsite and calculating the total amount of
+ * memory these tensors require.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Get the memory required for a primitive Relay function by calculating the total
+   * bytes of the live tensors at the callsite of the function.
+   *
+   * \param live_tensors The tensors that are live when the function is called.
+   * \return int The total number of bytes a function requires.
+   */
+  int GetMemoryUsage(const transform::VarSet& live_tensors) {
+    Array<Type> types_stack = {};
+    int memory_usage = 0;

Review Comment:
   can we widen the type for `memory_usage`? This would overflow at ~2 GB, which is pretty realistic these days. Maybe `uint64_t`?
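A quick way to see the threshold is the minimal Python sketch below. It uses `ctypes` to emulate storing a byte total in a 32-bit signed int versus a 64-bit unsigned int. The tensor shape is a hypothetical example. In C++, signed overflow is undefined behavior; the wraparound shown here is just what `ctypes` truncation produces, since `ctypes` does no overflow checking.

```python
import ctypes
from math import prod


def tensor_bytes(shape, dtype_bytes):
    """Total bytes for a statically shaped tensor."""
    return prod(shape) * dtype_bytes


# Hypothetical float32 activation of shape (1, 1024, 1024, 512): exactly 2**31 bytes.
nbytes = tensor_bytes((1, 1024, 1024, 512), 4)

# ctypes performs no overflow checking, so this emulates truncating the total
# into a 32-bit signed int versus a 64-bit unsigned int.
as_int32 = ctypes.c_int32(nbytes).value
as_uint64 = ctypes.c_uint64(nbytes).value

print(nbytes, as_int32, as_uint64)  # prints: 2147483648 -2147483648 2147483648
```

A single 2 GiB tensor is already enough to turn the 32-bit total negative.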



##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/annotate_used_memory.cc
+ * \brief Analyzes the memory pressure at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/transform.h>
+
+#include "../../transforms/device_aware_visitors.h"
+#include "../manifest_lifetimes.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analysing the liveness
+ * of the input/output tensors at the function callsite and calculating the total amount of
+ * memory these tensors require.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Get the memory required for a primitive Relay function by calculating the total
+   * bytes of the live tensors at the callsite of the function.
+   *
+   * \param live_tensors The tensors that are live when the function is called.
+   * \return int The total number of bytes a function requires.
+   */
+  int GetMemoryUsage(const transform::VarSet& live_tensors) {
+    Array<Type> types_stack = {};
+    int memory_usage = 0;
+
+    for (const Var& var : live_tensors) {
+      Type var_type = var->checked_type();
+      ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass.";
+      types_stack.push_back(var_type);
+    }
+
+    while (!types_stack.empty()) {
+      Type current_type = types_stack.back();
+      types_stack.pop_back();
+
+      if (const auto* tt_node = current_type.as<TupleTypeNode>()) {
+        for (const Type& type : tt_node->fields) {
+          types_stack.push_back(type);
+        }
+        continue;
+      } else if (const auto* ft_node = current_type.as<FuncTypeNode>()) {
+        types_stack.push_back(ft_node->ret_type);
+        continue;
+      }
+
+      const auto* tt_node = current_type.as<TensorTypeNode>();
+      ICHECK(tt_node) << "Expected TensorTypeNode but was " << current_type->GetTypeKey();
+      int total_tensor_bytes = GetTensorBytes(tt_node);
+      memory_usage += total_tensor_bytes;
+    }
+    return memory_usage;
+  }
+
+  /*!
+   * \brief Get the number of bytes a tensor requires.
+   *
+   * \param tensor_type_node The checked type of the tensor.
+   * \return int The number of bytes required.
+   */
+  int GetTensorBytes(const TensorTypeNode* tensor_type_node) {
+    PrimExpr size = tensor_type_node->Size();
+    const auto* size_int_imm = size.as<IntImmNode>();
+    ICHECK(size_int_imm) << "Expected tensor size to be an IntImmNode but was "
+                         << size->GetTypeKey();
+
+    int total_size = size_int_imm->value;
+    int dtype_bytes = tensor_type_node->dtype.bytes();
+    return total_size * dtype_bytes;
+  }
+
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    if (const auto* func_node = pre_let_node->value.as<FunctionNode>()) {
+      const auto let_bound_values = control_flow_graph_.let_map;

Review Comment:
   probably worth making this a reference so the map doesn't get copied unnecessarily



##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/annotate_used_memory.cc
+ * \brief Analyzes the memory pressure at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/transform.h>
+
+#include "../../transforms/device_aware_visitors.h"
+#include "../manifest_lifetimes.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analysing the liveness
+ * of the input/output tensors at the function callsite and calculating the total amount of
+ * memory these tensors require.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Get the memory required for a primitive Relay function by calculating the total
+   * bytes of the live tensors at the callsite of the function.
+   *
+   * \param live_tensors The tensors that are live when the function is called.
+   * \return int The total number of bytes a function requires.
+   */
+  int GetMemoryUsage(const transform::VarSet& live_tensors) {
+    Array<Type> types_stack = {};
+    int memory_usage = 0;
+
+    for (const Var& var : live_tensors) {
+      Type var_type = var->checked_type();
+      ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass.";
+      types_stack.push_back(var_type);
+    }
+
+    while (!types_stack.empty()) {
+      Type current_type = types_stack.back();
+      types_stack.pop_back();
+
+      if (const auto* tt_node = current_type.as<TupleTypeNode>()) {
+        for (const Type& type : tt_node->fields) {
+          types_stack.push_back(type);
+        }
+        continue;
+      } else if (const auto* ft_node = current_type.as<FuncTypeNode>()) {
+        types_stack.push_back(ft_node->ret_type);

Review Comment:
   why would a function show up here? is it a function call? it seems like we should just ignore functions bound to variables, since it's the actual result (e.g. `x: Tensor[...] = f_var(...)`) that uses space, not the function variable itself
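Here is a minimal sketch of the reviewer's suggestion. It uses a hypothetical, simplified type model (`ToyType`, `LiveBytes`) rather than Relay's real `Type` classes. Tuple types are flattened, tensor types contribute their size, and function types are skipped instead of being replaced by their return type. The sketch accumulates in `int64_t`, an assumption made here to avoid overflow on large tensors.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for Relay's type hierarchy.
struct ToyType {
  enum class Kind { kTensor, kTuple, kFunc } kind;
  std::vector<int64_t> shape;    // kTensor only: static shape
  int bits = 0;                  // kTensor only: scalar width in bits
  std::vector<ToyType> fields;   // kTuple only
};

ToyType Tensor(std::vector<int64_t> shape, int bits) {
  ToyType t;
  t.kind = ToyType::Kind::kTensor;
  t.shape = std::move(shape);
  t.bits = bits;
  return t;
}

ToyType Tuple(std::vector<ToyType> fields) {
  ToyType t;
  t.kind = ToyType::Kind::kTuple;
  t.fields = std::move(fields);
  return t;
}

ToyType Func() {
  ToyType t;
  t.kind = ToyType::Kind::kFunc;
  return t;
}

// Total bytes of the live values, using an explicit stack like the diff does.
int64_t LiveBytes(const std::vector<ToyType>& live_types) {
  std::vector<const ToyType*> stack;
  for (const ToyType& t : live_types) stack.push_back(&t);
  int64_t total = 0;
  while (!stack.empty()) {
    const ToyType* t = stack.back();
    stack.pop_back();
    switch (t->kind) {
      case ToyType::Kind::kTuple:
        // Flatten tuples into their fields.
        for (const ToyType& f : t->fields) stack.push_back(&f);
        break;
      case ToyType::Kind::kFunc:
        // A function-typed variable names code, not a buffer: it uses no tensor memory.
        break;
      case ToyType::Kind::kTensor: {
        int64_t elems = 1;
        for (int64_t d : t->shape) elems *= d;
        total += elems * ((t->bits + 7) / 8);
        break;
      }
    }
  }
  return total;
}
```

With this shape, the space for `x` in `x: Tensor[...] = f_var(...)` is counted through `x`'s own tensor type, and `f_var` adds nothing.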



##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/annotate_used_memory.cc
+ * \brief Analyzes the memory pressure at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/transform.h>
+
+#include "../../transforms/device_aware_visitors.h"
+#include "../manifest_lifetimes.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analysing the liveness
+ * of the input/output tensors at the function callsite and calculating the total amount of
+ * memory these tensors require.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Get the memory required for a primitive Relay function by calculating the total
+   * bytes of the live tensors at the callsite of the function.
+   *
+   * \param live_tensors The tensors that are live when the function is called.
+   * \return int The total number of bytes a function requires.
+   */
+  int GetMemoryUsage(const transform::VarSet& live_tensors) {
+    Array<Type> types_stack = {};
+    int memory_usage = 0;
+
+    for (const Var& var : live_tensors) {
+      Type var_type = var->checked_type();
+      ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass.";
+      types_stack.push_back(var_type);
+    }
+
+    while (!types_stack.empty()) {
+      Type current_type = types_stack.back();
+      types_stack.pop_back();
+
+      if (const auto* tt_node = current_type.as<TupleTypeNode>()) {
+        for (const Type& type : tt_node->fields) {
+          types_stack.push_back(type);
+        }
+        continue;
+      } else if (const auto* ft_node = current_type.as<FuncTypeNode>()) {
+        types_stack.push_back(ft_node->ret_type);
+        continue;
+      }
+
+      const auto* tt_node = current_type.as<TensorTypeNode>();
+      ICHECK(tt_node) << "Expected TensorTypeNode but was " << current_type->GetTypeKey();
+      int total_tensor_bytes = GetTensorBytes(tt_node);
+      memory_usage += total_tensor_bytes;
+    }
+    return memory_usage;
+  }
+
+  /*!
+   * \brief Get the number of bytes a tensor requires.
+   *
+   * \param tensor_type_node The checked type of the tensor.
+   * \return int The number of bytes required.
+   */
+  int GetTensorBytes(const TensorTypeNode* tensor_type_node) {
+    PrimExpr size = tensor_type_node->Size();
+    const auto* size_int_imm = size.as<IntImmNode>();
+    ICHECK(size_int_imm) << "Expected tensor size to be an IntImmNode but was "
+                         << size->GetTypeKey();
+
+    int total_size = size_int_imm->value;
+    int dtype_bytes = tensor_type_node->dtype.bytes();
+    return total_size * dtype_bytes;
+  }
+
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    if (const auto* func_node = pre_let_node->value.as<FunctionNode>()) {

Review Comment:
   yeah I am also confused, since (at least from an IR perspective) this binding is not a call to a (primitive?) function but just binds the function to a var. I would have expected `pre_let_node->value.as<CallNode>()` with some check that the op in the call is a func. Maybe I am missing some context for the IR form at this point? I'd definitely like a comment here



##########
src/relay/backend/manifest_lifetimes.cc:
##########
@@ -0,0 +1,367 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/manifest_lifetimes.cc
+ * \brief Analysis and explicit manifestation of variable lifetimes. NOTE: the input IR should be in
+ * ANF and post-memory-lowering (explicit manifestation of allocations).
+ */
+
+#include "manifest_lifetimes.h"
+
+#include <list>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+using support::Arena;
+using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
+
+ControlFlowGraph ControlFlowGraph::Create(Arena* arena, const Expr& body) {
+  return Creator().Create(arena, body);
+}
+
+ControlFlowGraph ControlFlowGraph::Creator::Create(Arena* arena, const Expr& body) {
+  arena_ = arena;
+  cfg_.entry = BasicBlock::Make(arena);
+  VisitExpr(body, cfg_.entry);
+  return std::move(cfg_);
+}
+
+void ControlFlowGraph::Creator::Succ(BasicBlockPtr from, BasicBlockPtr to) {
+  from->succ.push_back(to);
+  to->pred.push_back(from);
+}
+
+void ControlFlowGraph::Creator::VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) {
+  ICHECK(!in_func_) << "nested functions not supported by CFG analysis";
+  in_func_ = true;
+
+  // Unwrap the nested function and proceed normally.
+  if (f->HasNonzeroAttr(attr::kClosure)) {
+    ICHECK(f->body.as<FunctionNode>());
+    return VisitExpr(Downcast<Function>(f->body)->body, parent);
+  }
+
+  return VisitExpr(f->body, parent);
+}
+
+void ControlFlowGraph::Creator::VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) {
+  Expr expr = GetRef<Expr>(let_node);
+
+  while (const LetNode* inner_let_node = expr.as<LetNode>()) {
+    NodePtr curr_node = Node::Make(arena_, parent, expr);
+
+    ICHECK(!cfg_.let_map.count(expr));
+    cfg_.let_map[expr] = curr_node;
+    cfg_.reverse_post_order.push_back(curr_node);
+
+    // The basic block ends upon reaching control flow, with successor blocks corresponding to the
+    // control flow branch exprs (true/false in If, and one for each clause in Match).
+    if (const IfNode* ite = AsIgnoringOnDevice<IfNode>(inner_let_node->value)) {
+      // Create the basic blocks for each branch and mark them as successors to the current block.
+      BasicBlockPtr t_block = BasicBlock::Make(arena_);
+      BasicBlockPtr f_block = BasicBlock::Make(arena_);
+      Succ(parent, t_block);
+      Succ(parent, f_block);
+
+      VisitExpr(ite->true_branch, t_block);
+      VisitExpr(ite->false_branch, f_block);
+
+      // All subsequent bindings (and/or the body expr) will be in a new basic block.
+      BasicBlockPtr next = BasicBlock::Make(arena_);
+      Succ(t_block, next);
+      Succ(f_block, next);
+      parent = next;
+    } else if (const MatchNode* match = AsIgnoringOnDevice<MatchNode>(inner_let_node->value)) {
+      // Same as above but one for each pattern.
+      std::vector<BasicBlockPtr> clause_blocks;
+      BasicBlockPtr next = BasicBlock::Make(arena_);
+      for (const Clause& clause : match->clauses) {
+        BasicBlockPtr clause_block = BasicBlock::Make(arena_);
+        Succ(parent, clause_block);
+        Succ(clause_block, next);
+        VisitExpr(clause->rhs, clause_block);
+      }
+      parent = next;
+    }
+
+    expr = inner_let_node->body;
+  }
+
+  VisitExpr(expr, parent);
+}
+
+void ControlFlowGraph::Creator::VisitExpr_(const IfNode* if_node, BasicBlockPtr parent) {
+  // TODO(@altanh): is there a way of making this work?
+  LOG(FATAL) << "If expressions should be bound to variables.";
+}
+
+void ControlFlowGraph::Creator::VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent) {
+  // TODO(@altanh): same as If
+  LOG(FATAL) << "Match expressions should be bound to variables.";
+}
+
+VarSet VarUseCollector::VisitExpr_(const VarNode* var_node) { return {GetRef<Var>(var_node)}; }
+
+VarSet VarUseCollector::VisitExpr_(const CallNode* call_node) {
+  VarSet use = VisitExpr(call_node->op);
+  for (const Expr& arg : call_node->args) {
+    VarSet arg_use = VisitExpr(arg);
+    use.insert(arg_use.begin(), arg_use.end());
+  }
+  return use;
+}
+
+VarSet VarUseCollector::VisitExpr_(const TupleNode* tuple_node) {
+  VarSet use;
+  for (const Expr& field : tuple_node->fields) {
+    VarSet field_use = VisitExpr(field);
+    use.insert(field_use.begin(), field_use.end());
+  }
+  return use;
+}
+
+VarSet VarUseCollector::VisitExpr_(const TupleGetItemNode* get_node) {
+  return VisitExpr(get_node->tuple);
+}
+
+VarSet VarUseCollector::VisitExpr_(const IfNode* if_node) { return VisitExpr(if_node->cond); }
+
+VarSet VarUseCollector::VisitExpr_(const MatchNode* match_node) {
+  return VisitExpr(match_node->data);
+}
+
+UseDefAnalysis UseDefAnalysis::Analyze(const CFG& cfg) {
+  UseDefAnalysis a;
+
+  // One pass is sufficient.
+  for (auto it = cfg.reverse_post_order.begin(); it != cfg.reverse_post_order.end(); ++it) {
+    const CFG::NodePtr& node = *it;
+    if (const LetNode* let_node = AsIgnoringOnDevice<LetNode>(node->expr)) {
+      a.use[node] = a.use_collector.VisitExpr(let_node->value);
+      a.def[node] = let_node->var;
+    } else {
+      a.use[node] = a.use_collector.VisitExpr(node->expr);
+      a.def[node] = Var();
+    }
+  }
+
+  return a;
+}
+
+bool SetEqual(const VarSet& a, const VarSet& b) {
+  if (a.size() != b.size()) {
+    return false;
+  }
+  for (auto& xa : a) {
+    if (!b.count(xa)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+LivenessAnalysis LivenessAnalysis::Analyze(const ControlFlowGraph& cfg,
+                                           const UseDefAnalysis& use_def) {
+  LivenessAnalysis a;
+  std::list<CFG::NodePtr> worklist;
+
+  // Initialize worklist to post-order traversal for quick convergence.
+  worklist.insert(worklist.end(), cfg.reverse_post_order.rbegin(), cfg.reverse_post_order.rend());
+
+  // See https://lambda.uta.edu/cse5317/notes/node40.html for an overview of the algorithm.
+  auto visitor = [&](const CFG::NodePtr n) {
+    VarSet old_in_n = a.live_in[n];
+    VarSet old_out_n = a.live_out[n];
+
+    a.live_in[n] = use_def.use.at(n);
+    for (const Var& v : a.live_out[n]) {
+      if (!v.same_as(use_def.def.at(n))) {
+        a.live_in[n].insert(v);
+      }
+    }
+
+    a.live_out[n] = VarSet();
+    for (const CFG::NodePtr& s : n->GetSucc()) {
+      a.live_out[n].insert(a.live_in[s].begin(), a.live_in[s].end());
+    }
+
+    if (SetEqual(old_in_n, a.live_in[n]) && SetEqual(old_out_n, a.live_out[n])) {
+      // No need to update the worklist.
+    } else {
+      // Add predecessor nodes back to worklist (no need to add successors, since each node's
+      // in/out sets are not dependent on its predecessors).
+      for (const CFG::NodePtr& p : n->GetPred()) {
+        worklist.push_back(p);
+      }
+    }
+  };
+
+  while (!worklist.empty()) {
+    const CFG::NodePtr n = worklist.front();
+    worklist.pop_front();
+    visitor(n);
+  }
+
+  return a;
+}
+
+Expr KillInserter::VisitExpr_(const LetNode* let_node) {
+  Expr expr = GetRef<Expr>(let_node);
+  LetList ll;
+
+  while (const LetNode* inner_let_node = expr.as<LetNode>()) {
+    ll.Push(inner_let_node->var, VisitExpr(inner_let_node->value));
+
+    ICHECK(!inner_let_node->value.as<VarNode>()) << "aliasing should have been eliminated.";
+    ICHECK(cfg_->let_map.count(expr)) << "all Let exprs should be mapped in the CFG";
+
+    const ControlFlowGraph::NodePtr n = cfg_->let_map.at(expr);
+
+    const VarSet& li = lva_->live_in.at(n);
+    const VarSet& lo = lva_->live_out.at(n);
+
+    // Killed vars = live in - live out.
+    VarSet kills;
+    for (const Var& v : li) {
+      if (!lo.count(v)) {
+        kills.insert(v);
+      }
+    }
+
+    for (const Var& v : kills) {
+      ll.Push(Call(Op::Get("memory.kill"), {v}));
+    }
+
+    expr = inner_let_node->body;
+  }
+
+  return ll.Get(VisitExpr(expr));
+}
+
+Expr AliasEliminator::VisitExpr_(const LetNode* let_node) {
+  Expr expr = GetRef<Expr>(let_node);
+  LetList ll;
+  std::vector<Var> aliased_vars;
+
+  while (const LetNode* inner_let_node = expr.as<LetNode>()) {
+    const Var& var = inner_let_node->var;
+    const Expr& val = inner_let_node->value;
+    bool aliased = false;
+    ICHECK(!alias_.count(var));
+
+    if (const VarNode* alias_of_n = AsIgnoringOnDevice<VarNode>(val)) {
+      alias_[var] = Downcast<Var>(VisitExpr_(alias_of_n));
+      aliased = true;
+    } else if (AsIgnoringOnDevice<CallNode>(val)) {
+      // Copying to the same device is aliasing.
+      // WARNING: this must be kept in sync with the VM compiler logic in
+      // src/relay/backend/vm/compiler.cc, line 541, in DeviceAwareVisitExpr_(const CallNode*).

Review Comment:
   I feel a little weird about this move, since there's still some VM-specific logic in the code for manifesting lifetimes. The rest (the CFG and the basic analyses) can probably be factored out, though.
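For reference, the CFG and liveness pieces discussed here are the classic backward dataflow analysis. It computes live_in[n] = use[n] ∪ (live_out[n] − {def[n]}) and live_out[n] = the union of live_in[s] over the successors s, and iterates with a worklist until nothing changes. The variables killed at a node are live_in − live_out, as in `KillInserter` above. The sketch below is a simplified, VM-independent illustration on a hypothetical toy CFG with string-named variables (an if/else diamond), not the code in this PR. Unlike the diff, it recomputes live_out before live_in within each visit.

```cpp
#include <cassert>
#include <deque>
#include <set>
#include <string>
#include <utility>
#include <vector>

using VarSet = std::set<std::string>;

// Hypothetical CFG node: one let binding (def) and the variables it reads (use).
struct Node {
  VarSet use;
  std::string def;  // empty if the node defines nothing
  std::vector<int> succ;
  std::vector<int> pred;
};

struct Liveness {
  std::vector<VarSet> live_in, live_out;
};

Liveness Analyze(const std::vector<Node>& nodes) {
  Liveness a;
  a.live_in.resize(nodes.size());
  a.live_out.resize(nodes.size());
  std::deque<int> worklist;
  // Seed in reverse order so successors tend to be visited before predecessors.
  for (int i = static_cast<int>(nodes.size()) - 1; i >= 0; --i) worklist.push_back(i);
  while (!worklist.empty()) {
    int n = worklist.front();
    worklist.pop_front();
    // live_out[n] = union of live_in over successors.
    VarSet out;
    for (int s : nodes[n].succ) out.insert(a.live_in[s].begin(), a.live_in[s].end());
    // live_in[n] = use[n] plus everything live out that n does not define.
    VarSet in = nodes[n].use;
    for (const std::string& v : out) {
      if (v != nodes[n].def) in.insert(v);
    }
    if (in != a.live_in[n] || out != a.live_out[n]) {
      a.live_in[n] = std::move(in);
      a.live_out[n] = std::move(out);
      // Only predecessors depend on this node's sets.
      for (int p : nodes[n].pred) worklist.push_back(p);
    }
  }
  return a;
}

// Variables whose last use is at node n: live on entry but not on exit.
VarSet Kills(const Liveness& a, int n) {
  VarSet k;
  for (const std::string& v : a.live_in[n]) {
    if (!a.live_out[n].count(v)) k.insert(v);
  }
  return k;
}
```

In the test diamond (0 → 1 → {2, 3} → 4), `b` is still live out of node 1 but dies in the true branch (node 2), while `a` survives both branches and dies at the join (node 4).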



[GitHub] [tvm] manupa-arm commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r877258645


##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -1063,6 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     }
 
     mod = transform::ToANormalForm()(mod);
+    mod = transform::InferType()(mod);
+    mod = transform::AnnotateUsedMemory()(mod);

Review Comment:
   This is mainly to get an idea of how aggressive the scheduling needs to be in terms of memory when we lower them to TIR.
   
   Yes, USMP will analyze and optimize further post-scheduling, attempting to hit "a" theoretical minimum (i.e. memory pressure).



[GitHub] [tvm] lhutton1 commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r878091500


##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/annotate_used_memory.cc
+ * \brief Analyzes the memory pressure at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/transform.h>
+
+#include "../../transforms/device_aware_visitors.h"
+#include "../manifest_lifetimes.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analysing the liveness
+ * of the input/output tensors at the function callsite and calculating the total amount of
+ * memory these tensors require.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Get the memory required for a primitive Relay function by calculating the total
+   * bytes of the live tensors at the callsite of the function.
+   *
+   * \param live_tensors The tensors that are live when the function is called.
+   * \return int The total number of bytes a function requires.
+   */
+  int GetMemoryUsage(const transform::VarSet& live_tensors) {
+    Array<Type> types_stack = {};
+    int memory_usage = 0;
+
+    for (const Var& var : live_tensors) {
+      Type var_type = var->checked_type();
+      ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass.";
+      types_stack.push_back(var_type);
+    }
+
+    while (!types_stack.empty()) {
+      Type current_type = types_stack.back();
+      types_stack.pop_back();
+
+      if (const auto* tt_node = current_type.as<TupleTypeNode>()) {
+        for (const Type& type : tt_node->fields) {
+          types_stack.push_back(type);
+        }
+        continue;
+      } else if (const auto* ft_node = current_type.as<FuncTypeNode>()) {
+        types_stack.push_back(ft_node->ret_type);
+        continue;
+      }
+
+      const auto* tt_node = current_type.as<TensorTypeNode>();
+      ICHECK(tt_node) << "Expected TensorTypeNode but was " << current_type->GetTypeKey();
+      int total_tensor_bytes = GetTensorBytes(tt_node);
+      memory_usage += total_tensor_bytes;
+    }
+    return memory_usage;
+  }
+
+  /*!
+   * \brief Get the number of bytes a tensor requires.
+   *
+   * \param tensor_type_node The checked type of the tensor.
+   * \return int The number of bytes required.
+   */
+  int GetTensorBytes(const TensorTypeNode* tensor_type_node) {
+    PrimExpr size = tensor_type_node->Size();
+    const auto* size_int_imm = size.as<IntImmNode>();
+    ICHECK(size_int_imm) << "Expected tensor size to be an IntImmNode but was "
+                         << size->GetTypeKey();
+
+    int total_size = size_int_imm->value;
+    int dtype_bytes = tensor_type_node->dtype.bytes();
+    return total_size * dtype_bytes;
+  }
+
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    if (const auto* func_node = pre_let_node->value.as<FunctionNode>()) {

Review Comment:
   Great point, thanks. In hindsight, we should definitely be visiting the callsite rather than the function. Previously I thought I could get all the information needed just by visiting the function, since checking whether the call node's op is a function is a little more complex when the let bindings get in the way. But as you both rightly mention, this isn't the correct way to do it.
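Here is a minimal sketch of the callsite-first approach on a hypothetical flattened let chain. The chain mirrors the `let %x_0 = fn (..., Primitive=1) {...}; let %x_1 = %x_0(%input);` form in the docstring example further down. The idea is to look only at call bindings and resolve the callee variable through the earlier let bindings to decide whether it names a primitive function. `Binding` and `PrimitiveCallsites` are illustrative names, not TVM APIs.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical flattened ANF binding: `let var = <value>`.
struct Binding {
  enum class Kind { kPrimitiveFunc, kCall, kOther } kind;
  std::string var;     // the let-bound variable
  std::string callee;  // kCall only: the variable being called
};

// Returns the variables bound to calls whose callee resolves to a primitive function.
std::vector<std::string> PrimitiveCallsites(const std::vector<Binding>& lets) {
  std::map<std::string, const Binding*> bound;  // var -> the binding that defined it
  std::vector<std::string> callsites;
  for (const Binding& b : lets) {
    if (b.kind == Binding::Kind::kCall) {
      // Look through the let binding to see what the callee variable refers to.
      auto it = bound.find(b.callee);
      if (it != bound.end() && it->second->kind == Binding::Kind::kPrimitiveFunc) {
        callsites.push_back(b.var);
      }
    }
    bound[b.var] = &b;
  }
  return callsites;
}
```

The function binding itself (`x_0`) is never reported. Only the call that produces a tensor (`x_1`) is reported, and that is where the live set would be measured.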



[GitHub] [tvm] areusch commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
areusch commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r877247663


##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/annotate_used_memory.cc
+ * \brief Analyzes the memory pressure at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/transform.h>
+
+#include "../../transforms/device_aware_visitors.h"
+#include "../manifest_lifetimes.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analysing the liveness
+ * of the input/output tensors at the function callsite and calculating the total amount of
+ * memory these tensors require.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Get the memory required for a primitive Relay function by calculating the total
+   * bytes of the live tensors at the callsite of the function.
+   *
+   * \param live_tensors The tensors that are live when the function is called.
+   * \return int The total number of bytes a function requires.
+   */
+  int GetMemoryUsage(const transform::VarSet& live_tensors) {
+    Array<Type> types_stack = {};
+    int memory_usage = 0;
+
+    for (const Var& var : live_tensors) {
+      Type var_type = var->checked_type();
+      ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass.";
+      types_stack.push_back(var_type);
+    }
+
+    while (!types_stack.empty()) {
+      Type current_type = types_stack.back();
+      types_stack.pop_back();
+
+      if (const auto* tt_node = current_type.as<TupleTypeNode>()) {
+        for (const Type& type : tt_node->fields) {
+          types_stack.push_back(type);
+        }
+        continue;
+      } else if (const auto* ft_node = current_type.as<FuncTypeNode>()) {
+        types_stack.push_back(ft_node->ret_type);
+        continue;
+      }
+
+      const auto* tt_node = current_type.as<TensorTypeNode>();
+      ICHECK(tt_node) << "Expected TensorTypeNode but was " << current_type->GetTypeKey();
+      int total_tensor_bytes = GetTensorBytes(tt_node);
+      memory_usage += total_tensor_bytes;
+    }
+    return memory_usage;
+  }
+
+  /*!
+   * \brief Get the number of bytes a tensor requires.
+   *
+   * \param tensor_type_node The checked type of the tensor.
+   * \return int The number of bytes required.
+   */
+  int GetTensorBytes(const TensorTypeNode* tensor_type_node) {
+    PrimExpr size = tensor_type_node->Size();
+    const auto* size_int_imm = size.as<IntImmNode>();
+    ICHECK(size_int_imm) << "Expected tensor size to be an IntImmNode but was "
+                         << size->GetTypeKey();
+
+    int total_size = size_int_imm->value;
+    int dtype_bytes = tensor_type_node->dtype.bytes();
+    return total_size * dtype_bytes;
+  }
+
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    if (const auto* func_node = pre_let_node->value.as<FunctionNode>()) {

Review Comment:
   is this matching a lowered function call? it would be great if we could add a comment describing the expected construct here.



##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -1063,6 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     }
 
     mod = transform::ToANormalForm()(mod);
+    mod = transform::InferType()(mod);
+    mod = transform::AnnotateUsedMemory()(mod);

Review Comment:
   how come you want to do this at the relay level when we would further analyze this in USMP down below?



##########
src/relay/backend/aot/annotate_used_memory.cc:
##########
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/aot/annotate_used_memory.cc
+ * \brief Analyzes the memory pressure at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/transform.h>
+
+#include "../../transforms/device_aware_visitors.h"
+#include "../manifest_lifetimes.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+namespace aot {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analysing the liveness
+ * of the input/output tensors at the function callsite and calculating the total amount of
+ * memory these tensors require.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Get the memory required for a primitive Relay function by calculating the total
+   * bytes of the live tensors at the callsite of the function.
+   *
+   * \param live_tensors The tensors that are live when the function is called.
+   * \return int The total number of bytes a function requires.
+   */
+  int GetMemoryUsage(const transform::VarSet& live_tensors) {
+    Array<Type> types_stack = {};
+    int memory_usage = 0;
+
+    for (const Var& var : live_tensors) {
+      Type var_type = var->checked_type();
+      ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass.";
+      types_stack.push_back(var_type);
+    }
+
+    while (!types_stack.empty()) {
+      Type current_type = types_stack.back();
+      types_stack.pop_back();
+
+      if (const auto* tt_node = current_type.as<TupleTypeNode>()) {
+        for (const Type& type : tt_node->fields) {
+          types_stack.push_back(type);
+        }
+        continue;
+      } else if (const auto* ft_node = current_type.as<FuncTypeNode>()) {
+        types_stack.push_back(ft_node->ret_type);
+        continue;
+      }
+
+      const auto* tt_node = current_type.as<TensorTypeNode>();
+      ICHECK(tt_node) << "Expected TensorTypeNode but was " << current_type->GetTypeKey();
+      int total_tensor_bytes = GetTensorBytes(tt_node);
+      memory_usage += total_tensor_bytes;
+    }
+    return memory_usage;
+  }
+
+  /*!
+   * \brief Get the number of bytes a tensor requires.
+   *
+   * \param tensor_type_node The checked type of the tensor.
+   * \return int The number of bytes required.
+   */
+  int GetTensorBytes(const TensorTypeNode* tensor_type_node) {

Review Comment:
   per https://github.com/apache/tvm/blob/main/src/relay/backend/te_compiler.cc#L1002 there are many places where we already do this. could you use a utils function here instead of re-implementing it?
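The shared helper the reviewer links to is not reproduced here. As a hedged, standalone illustration of what such a helper computes, the hypothetical `TensorBytes` below multiplies the static dimensions and rounds each scalar up to whole bytes. It accumulates in `int64_t`, since a 32-bit `int` (as used in the diff) overflows for tensors larger than about 2 GiB. For instance, a (1, 2, 2, 4) int8 tensor needs 16 bytes. One such input plus one such output therefore account for the `io_used_memory=32` in the docstring example further down.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: bytes needed by a dense tensor with a fully static shape.
// Each scalar is rounded up to whole bytes; vector lanes are not modeled here.
int64_t TensorBytes(const std::vector<int64_t>& shape, int dtype_bits) {
  int64_t num_elements = 1;  // an empty shape is a scalar
  for (int64_t dim : shape) {
    num_elements *= dim;
  }
  const int64_t bytes_per_element = (dtype_bits + 7) / 8;
  return num_elements * bytes_per_element;
}
```

The last test case below (4096 x 4096 x 256 float32, 17,179,869,184 bytes) would not fit in a 32-bit `int`.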



[GitHub] [tvm] altanh commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

Posted by GitBox <gi...@apache.org>.
altanh commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r888237836


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/annotate_used_memory.cc
+ * \brief Analyzes the used memory at the callsite of primitive functions.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../transforms/device_aware_visitors.h"
+#include "./liveness_analysis.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace relay {
+namespace backend {
+
+/*!
+ * \brief Annotates the memory usage of each primitive function by analyzing the liveness
+ * of the input/output tensors at each function callsite and calculating the total amount of
+ * memory these tensors require. This is added as a "used_memory" annotation to the function
+ * in question. In addition, the containing function is annotated with an "io_used_memory"
+ * annotation which refers to the total memory required for the IO tensors.
+ *
+ * A simple example:
+ *
+ * Before:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1 = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * After:
+ * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
+ *   let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32)
+ *       -> Tensor[(1, 2, 2, 4), int8] {
+ *     nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
+ *   };
+ *   let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
+ *   %x_1
+ * }
+ *
+ * Note that in the simple example above io_used_memory and used_memory are the same since there
+ * is only one primitive function.
+ */
+class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
+ public:
+  AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
+                            const transform::LivenessAnalysis& lva)
+      : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
+
+  /*!
+   * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
+   * added to the input function which refers to the total size required for the IO
+   * tensors.
+   */
+  Function operator()(const Function& func) {
+    uint64_t io_used_memory = 0;
+
+    // Inputs
+    for (const Var& param : func->params) {
+      Type type = param->checked_type();
+      ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+      io_used_memory += CalculateRelayExprSizeBytes(type);
+    }
+
+    // Outputs
+    Type type = func->body->checked_type();
+    ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+    io_used_memory += CalculateRelayExprSizeBytes(type);
+
+    Expr new_func_body = VisitExpr(func->body);
+    Function new_func = WithFields(func, func->params, new_func_body);
+    return WithAttr(std::move(new_func), "io_used_memory",
+                    tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
+  }
+
+  /*!
+   * \brief Establish which let bindings have primitive function values.
+   */
+  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) {
+    if (const auto* func_node = value.as<FunctionNode>()) {
+      ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
+          << "Expect top-level functions to be primitive.";
+      let_bound_prim_func_.insert(var);
+    }
+    return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
+  }
+
+  /*!
+   * \brief Visit let nodes and perform one of two actions depending on their value:
+   *
+   * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
+   *               primitive functions.
+   *
+   * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
+   *                   previous analysis at the callsite.
+   *
+   */
+  Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
+    Var let_var = post_let_node->var;
+    Expr let_value = IgnoreOnDevice(post_let_node->value);
+
+    if (let_value->IsInstance<CallNode>()) {
+      Call callsite = Downcast<Call>(let_value);
+      if (CheckPrimitiveFunctionCall(callsite)) {
+        Var call_op = Downcast<Var>(callsite->op);
+
+        // Find all the vars that are live at the callsite. This is done by merging the
+        // live-in and live-out var sets and then removing the var that references the
+        // primitive function itself, since we don't want it included in the calculation.
+        const transform::ControlFlowGraph::NodePtr cfg_node =
+            control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
+        transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);
+        const transform::VarSet& live_out = liveness_.live_out.at(cfg_node);
+        live_tensors.insert(live_out.begin(), live_out.end());
+        live_tensors.erase(call_op);
+
+        // Calculate size of live tensors and store to allow annotation when the function
+        // gets visited.
+        uint64_t used_memory = 0;
+        for (const auto& var : live_tensors) {
+          Type type = var->checked_type();
+          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+          used_memory += CalculateRelayExprSizeBytes(type);

Review Comment:
   sg!
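A minimal sketch of the callsite calculation in `PostVisitLet_`, in plain Python: merge the live-in and live-out sets, drop the variable bound to the primitive function, and sum the sizes of what remains. The function names are hypothetical. The usage note assumes that, for the docstring example, the analysis reports `%input` and `%x_0` as live-in and `%x_1` as live-out at the call, with each `(1, 2, 2, 4)` int8 tensor taking 16 bytes.

```python
from typing import Dict, Hashable, Iterable, Set


def used_memory_at_callsite(live_in: Set[Hashable], live_out: Set[Hashable],
                            callee: Hashable, size_of: Dict[Hashable, int]) -> int:
    # Everything live on entry to or exit from the call occupies memory during it.
    live = set(live_in) | set(live_out)
    # The var bound to the primitive function is not a tensor, so exclude it.
    live.discard(callee)
    return sum(size_of[var] for var in live)


def io_used_memory(param_sizes: Iterable[int], output_size: int) -> int:
    # Inputs of the containing function plus its output.
    return sum(param_sizes) + output_size
```

Under those assumptions, `used_memory_at_callsite({"%input", "%x_0"}, {"%x_1"}, "%x_0", {"%input": 16, "%x_1": 16})` gives 32 and `io_used_memory([16], 16)` also gives 32, matching the docstring.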



[GitHub] [tvm] manupa-arm commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

manupa-arm commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r883472217


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+        // Calculate size of live tensors and store to allow annotation when the function
+        // gets visited.
+        uint64_t used_memory = 0;
+        for (const auto& var : live_tensors) {
+          Type type = var->checked_type();
+          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+          used_memory += CalculateRelayExprSizeBytes(type);
+        }
+        used_memory_annotations_[call_op] = used_memory;

Review Comment:
   Thinking about it a bit more, I think we need the max of the used memory across all calls to a given function. See my comment above as well.
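To illustrate the concern: the quoted code stores one value per callee var, so if the same let-bound primitive function is called more than once, a later callsite overwrites the value from an earlier one. A minimal sketch of keeping the peak instead, assuming the per-callsite values are already available as `(callee, used_memory)` pairs (a hypothetical input format, not what the pass stores):

```python
from typing import Dict, Iterable, Tuple


def max_used_memory_per_function(callsites: Iterable[Tuple[str, int]]) -> Dict[str, int]:
    peak: Dict[str, int] = {}
    for callee, used in callsites:
        # Keep the largest value seen so far instead of overwriting it.
        peak[callee] = max(used, peak.get(callee, 0))
    return peak
```

For example, `[("%f", 48), ("%g", 32), ("%f", 64), ("%f", 16)]` yields `{"%f": 64, "%g": 32}`, whereas plain assignment would leave `%f` at 16.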



[GitHub] [tvm] lhutton1 commented on a diff in pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

lhutton1 commented on code in PR #11208:
URL: https://github.com/apache/tvm/pull/11208#discussion_r885076297


##########
src/relay/backend/annotate_used_memory.cc:
##########
@@ -0,0 +1,222 @@
+        // Calculate size of live tensors and store to allow annotation when the function
+        // gets visited.
+        uint64_t used_memory = 0;
+        for (const auto& var : live_tensors) {
+          Type type = var->checked_type();
+          ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
+          used_memory += CalculateRelayExprSizeBytes(type);
+        }
+        used_memory_annotations_[call_op] = used_memory;

Review Comment:
   Thanks for the discussion. Yes, a list sounds like a good idea so that no information is lost (hopefully it's useful to someone in the future).
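The list-based alternative can be sketched in plain Python: record every callsite's value in call order and let consumers reduce it (for example with `max`) when they only need one number. The `(callee, used_memory)` pairs and the function name are hypothetical, used only for illustration.

```python
from typing import Dict, Iterable, List, Tuple


def used_memory_lists(callsites: Iterable[Tuple[str, int]]) -> Dict[str, List[int]]:
    per_function: Dict[str, List[int]] = {}
    for callee, used in callsites:
        # Append rather than overwrite so that no callsite is lost.
        per_function.setdefault(callee, []).append(used)
    return per_function
```

For example, `used_memory_lists([("%f", 48), ("%g", 32), ("%f", 64)])` returns `{"%f": [48, 64], "%g": [32]}`, and taking `max` over each list recovers the peak per function.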



[GitHub] [tvm] manupa-arm merged pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

manupa-arm merged PR #11208:
URL: https://github.com/apache/tvm/pull/11208


[GitHub] [tvm] zhaoyang-star commented on pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

zhaoyang-star commented on PR #11208:
URL: https://github.com/apache/tvm/pull/11208#issuecomment-1276903736

   Hi @lhutton1, thanks for your contribution.
   After running the FuseOps pass, I want to get the memory usage per op or per primitive function from the `AnnotateUsedMemory` pass for future optimization. I took a ResNet-18 Relay model and used it as the input IRModule of the `AnnotateUsedMemory` pass, but the output IRModule has no `used_memory` attribute. Test code as follows:
   ```python
   import pytest
   from collections import OrderedDict
   import numpy as np
   import tvm
   from tvm import relay
   from tvm.relay import testing
   
   
   def AnnotateUsedMemory():
       return relay.transform._ffi_api.AnnotateUsedMemory()
   
   
   def _get_data(in_data_shapes, dtype="float32"):
       in_data = OrderedDict()
       for name, shape in in_data_shapes.items():
           in_data[name] = np.random.uniform(size=shape).astype(dtype)
       return in_data
   
   
   def _run_relay(mod, params, in_data, pass_enabled):
       target = "llvm"
       dev = tvm.device("llvm", 0)
       in_data = [tvm.nd.array(value) for value in in_data.values()]
   
       if pass_enabled:
           mod = relay.transform.InferType()(mod)
           mod = relay.transform.ToANormalForm()(mod)
           mod = relay.transform.InferType()(mod)
           mod = AnnotateUsedMemory()(mod)
           # create primitive functions
           mod = relay.transform.FuseOps()(mod)
   
       print(f'\nmod when AnnotateUsedMemory is {pass_enabled}:\n {mod}')
   
       out_data = relay.create_executor(
           "graph", mod, device=dev, target=target).evaluate()(*in_data, **params)
       return out_data.numpy()
   
   
   def _verify_results(mod, params, in_data, rtol=1e-5, atol=1e-5):
       before = _run_relay(mod, params, in_data, False)
       after = _run_relay(mod, params, in_data, True)
       np.testing.assert_allclose(before, after, rtol, atol)
   
   
   def test_resnet():
       num_class = 1000
       in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)})
       in_data = _get_data(in_data_shapes, dtype="float32")
       for n in [18]:  # 18, 34, 50, 101
           mod, params = tvm.relay.testing.resnet.get_workload(
               batch_size=1, num_classes=num_class, num_layers=n)
           _verify_results(mod, params, in_data)
   
   
   if __name__ == "__main__":
       pytest.main([__file__])
   ```
   
   I am not familiar with the `AnnotateUsedMemory` pass. Can your pass provide the memory usage per op or per primitive function? If not, how could that be obtained based on your pass? Thanks in advance ^_^


[GitHub] [tvm] lhutton1 commented on pull request #11208: [AOT] Calculate used memory at the callsite of primitive functions

lhutton1 commented on PR #11208:
URL: https://github.com/apache/tvm/pull/11208#issuecomment-1280579669

   Hi @zhaoyang-star, yes, I was able to reproduce the issue with your script. My script would be the same as yours, just with a different pass order as mentioned above. Placing `FuseOps` before `AnnotateUsedMemory` seems like the correct thing to do here; if you print out the module (`mod`) after the `AnnotateUsedMemory` pass, you should be able to see the `used_memory` annotations. The `Check failed: (tensor_type) is false:` error comes later in the compilation, so it seems that some later optimization passes cannot handle ANF yet.

